/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/quantized/cpu/embeddingxb.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cmath>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

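// Unpacks the sub-byte weight at position `index` from the packed uint8 data
// `w_data` and returns it as a signed value: 2-bit values are stored four per
// byte (low bits first) with an offset of 2, giving the range [-2, 1]; 4-bit
// values are stored two per byte (high nibble first) with an offset of 8,
// giving the range [-8, 7].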
static inline int32_t
weight_value(const unsigned char* w_data, int32_t index, int32_t weight_nbit) {
  if (weight_nbit == 2) {
    int32_t subbyte = index % 4;
    index >>= 2;
    switch (subbyte) {
      case 0:
        return (int32_t)(w_data[index] & 3) - 2;
      case 1:
        return (int32_t)((w_data[index] & 12) >> 2) - 2;
      case 2:
        return (int32_t)((w_data[index] & 48) >> 4) - 2;
      case 3:
        return (int32_t)((w_data[index] & 192) >> 6) - 2;
    }
  } else if (weight_nbit == 4) {
    int32_t odd = index & 1;
    index >>= 1;
    if (odd) {
      return (int32_t)(w_data[index] & 0x0F) - 8;
    } else {
      return (int32_t)((w_data[index] >> 4) & 0x0F) - 8;
    }
  }

  ET_CHECK_MSG(false, "invalid weight_nbit");
}

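// Converts a packed dimension (number of uint8 bytes per row) into the
// unpacked embedding dimension (number of weight values per row).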
static inline int32_t get_embedding_dim(
    int32_t packed_dim,
    int32_t weight_nbit) {
  ET_CHECK_MSG(8 % weight_nbit == 0, "weight_nbit must divide 8");
  int packed_values_per_byte = 8 / weight_nbit;
  return packed_dim * packed_values_per_byte;
}

/**
 * Asserts that the parameters are valid.
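 *
 * Expected inputs (as checked below): `weight` is a 2-D uint8 tensor of
 * packed sub-byte values; `weight_scales` is 1-D (per channel) or 2-D (per
 * group) with size(0) == weight.size(0); optional zero points match the
 * shape of weight_scales; `indices` must be Long; `out` and the quantization
 * parameters must be Float or Half.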
 */
void check_embedding_xbit_args(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  ET_CHECK_MSG(8 % weight_nbit == 0, "nbit must divide 8");

  ET_CHECK_MSG(
      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());

  ET_CHECK_MSG(
      weight_scales.dim() == 1 || weight_scales.dim() == 2,
      "weight_scales must be 1D or 2D but got %zd dims",
      weight_scales.dim());

  ET_CHECK_MSG(
      weight_scales.size(0) == weight.size(0),
      "Number of scales must be == weight.size(0)=%zd"
      ", but got %zd",
      weight.size(0),
      weight_scales.size(0));

  if (weight_scales.dim() == 2) {
    auto num_groups = weight_scales.size(1);
    ET_CHECK_MSG(
        // each packed uint8 column expands to packed_values_per_byte columns
        get_embedding_dim(weight.size(1), weight_nbit) % num_groups == 0,
        "Number of groups must divide the unpacked embedding dim"
        " (weight.size(1)=%zd packed bytes per row), but got # of groups = %zd",
        weight.size(1),
        num_groups);
  }

  ET_CHECK_MSG(
      weight.scalar_type() == ScalarType::Byte,
      "weight.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight.scalar_type()));

  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float ||
          out.scalar_type() == ScalarType::Half,
      "out.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(out.scalar_type()));

  ET_CHECK_MSG(
      weight_scales.scalar_type() == ScalarType::Float ||
          weight_scales.scalar_type() == ScalarType::Half,
      "weight_scales.scalar_type() %" PRId8 " is not supported",
      static_cast<int8_t>(weight_scales.scalar_type()));

  if (opt_weight_zero_points.has_value()) {
    ET_CHECK_MSG(
        opt_weight_zero_points.value().dim() == weight_scales.dim(),
        "weight_zero_points' rank must match that of weight_scales. "
        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
        static_cast<int8_t>(weight_scales.dim()));

    ET_CHECK_MSG(
        opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
        "weight zero points scalar type %" PRId8
        " does not match out.scalar_type()",
        static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));

    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
      ET_CHECK_MSG(
          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
          "Dimension size mismatch at dim %" PRIi32
          ". weight_zero_points size = %zd"
          ", weight_scales size = %zd.",
          i,
          opt_weight_zero_points.value().size(i),
          weight_scales.size(i));
    }
  }

  ET_CHECK_MSG(
      indices.scalar_type() == ScalarType::Long,
      "indices.scalar_type() %" PRId8 " is not Long; only Long is supported",
      static_cast<int8_t>(indices.scalar_type()));

  ET_CHECK_MSG(
      weight_quant_min <= weight_quant_max,
      "weight quant min: %" PRId64
      " is greater than weight quant max: %" PRId64,
      weight_quant_min,
      weight_quant_max);

  if (out_dtype.has_value()) {
    ET_CHECK_MSG(
        out.scalar_type() == out_dtype.value(),
        "output_dtype must match the dtype of the out tensor");
  }
}

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight tensor is always uint8 (packed sub-byte values).
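 *
 * For output element j of a selected row, with group g = j / group_size:
 *   out[j] = (weight_value(w_data, j, weight_nbit) - zero_point[g]) * scale[g]
 * where the zero point is treated as 0 if no zero points are provided.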
 */
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
void embedding_xbit_per_channel(
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  auto embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);

  int32_t num_groups_per_channel = 1;
  if (weight_scales.dim() == 2) {
    num_groups_per_channel = weight_scales.size(1);
  }
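  // Number of unpacked weight values that share one scale/zero point.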
  int32_t group_size = embedding_dim / num_groups_per_channel;

  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

  const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
  const CTYPE_PARAMS* zero_points = nullptr;
  if (opt_weight_zero_points.has_value()) {
    zero_points = opt_weight_zero_points.value().const_data_ptr<CTYPE_PARAMS>();
  }

  for (int i = 0; i < indices.numel(); i++) {
    int64_t index = indices_ptr[i];
    // Each row has num_groups_per_channel quantization parameters; compute
    // the offset of this row's first scale/zero point.
    int32_t qparams_index = index * num_groups_per_channel;
    CTYPE_PARAMS zp = 0.0;
    const CTYPE_PARAMS* scale_ptr = scales + qparams_index;
    const CTYPE_PARAMS* zero_points_ptr = nullptr;
    if (opt_weight_zero_points.has_value()) {
      zero_points_ptr = zero_points + qparams_index;
    }

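    // Each row of `weight` holds weight.size(1) packed bytes; advance to the
    // start of the selected row before unpacking its values.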
    const uint8_t* w_data =
        weight.const_data_ptr<uint8_t>() + weight.size(1) * index;

    for (int j = 0; j < embedding_dim; ++j) {
      int32_t group_id = j / group_size;
      const CTYPE_PARAMS scale = scale_ptr[group_id];
      if (opt_weight_zero_points.has_value()) {
        zp = zero_points_ptr[group_id];
      }
      out_data[j] = static_cast<CTYPE_OUT>(
          (static_cast<float>(weight_value(w_data, j, weight_nbit)) -
           static_cast<float>(zp)) *
          static_cast<float>(scale));
    }
    out_data += embedding_dim;
  }
}

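// Resizes `out` so that its leading dimensions match `indices` and its last
// dimension is the unpacked embedding dimension.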
void resize_out_tensor(
    const Tensor& weight,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  exec_aten::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  const size_t embedding_dim = get_embedding_dim(weight.size(1), weight_nbit);
  expected_output_size[out.dim() - 1] = embedding_dim;

  exec_aten::ArrayRef<exec_aten::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  torch::executor::Error err = resize_tensor(out, output_size);
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in quantized_embedding_xbit_out");
}

} // namespace

/**
 * Retrieves the embeddings specified by indices, dequantizes them, and stores
 * them in out. The weight is quantized per channel, with a scale and zero_point
 * for each embedding.
 *
 * This is the out variant of torch.ops.quantized.embedding_xbit.
 *
 * NOTE: quant_min, quant_max, and Dtype are not used in computation, but rather
 * metadata that is passed around which can be useful for pattern matching. See
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
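 *
 * Example (illustrative shapes): with weight_nbit = 4 and a packed weight of
 * shape [vocab_size, 64] (uint8), the unpacked embedding dim is 64 * 2 = 128,
 * so indices of shape [B, S] yield an out tensor of shape [B, S, 128].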
 */
Tensor& quantized_embedding_xbit_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns floating-point output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  ScalarType out_type = out.scalar_type();

  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_xbit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_type,
      out,
      weight_nbit);

  constexpr auto name = "quantized_decomposed::embedding_xbit.out";
  ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
    embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
        weight,
        weight_scales,
        opt_weight_zero_points,
        indices,
        out,
        weight_nbit);
  });

  return out;
}

Tensor& quantized_embedding_xbit_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    Tensor& out,
    int weight_nbit) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out, weight_nbit);
  return quantized_embedding_xbit_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out,
      weight_nbit);
}

Tensor& quantized_embedding_xbit_dtype_out(
    // TODO Evaluate whether this name is appropriate for an operator that
    // takes non-quantized input and returns floating-point output
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const int64_t weight_quant_min,
    const int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  // TODO (jakeszwe): improve these to account for the size of out in relation
  // to weight and indices accounting for a possible batch dimension
  check_embedding_xbit_args(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      weight_nbit);

  ScalarType params_type = weight_scales.scalar_type();
  ScalarType out_type = out.scalar_type();

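  // The scale/zero-point type (CTYPE_P) and the output type (CTYPE_OUT) may
  // differ here, so dispatch over both.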
  constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
  ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
    ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
      embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
          weight,
          weight_scales,
          opt_weight_zero_points,
          indices,
          out,
          weight_nbit);
    });
  });

  return out;
}

Tensor& quantized_embedding_xbit_dtype_out(
    KernelRuntimeContext& context,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    int64_t weight_quant_min,
    int64_t weight_quant_max,
    const Tensor& indices,
    exec_aten::optional<ScalarType> out_dtype,
    Tensor& out,
    int weight_nbit) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  (void)context;
  resize_out_tensor(weight, indices, out, weight_nbit);
  return quantized_embedding_xbit_dtype_out(
      weight,
      weight_scales,
      opt_weight_zero_points,
      weight_quant_min,
      weight_quant_max,
      indices,
      out_dtype,
      out,
      weight_nbit);
}

} // namespace native
} // namespace executor
} // namespace torch