1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
10 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/platform/runtime.h>
16 #include <executorch/test/utils/DeathTest.h>
17
18 #include <gtest/gtest.h>
19 #include <limits>
20
21 using namespace ::testing;
22 using exec_aten::ArrayRef;
23 using exec_aten::optional;
24 using exec_aten::Scalar;
25 using exec_aten::ScalarType;
26 using exec_aten::Tensor;
27 using executorch::runtime::KernelRuntimeContext;
28 using torch::executor::native::dequantize_per_tensor_out;
29 using torch::executor::native::embedding_out;
30 using torch::executor::native::quantize_per_tensor_out;
31 using torch::executor::native::quantized_embedding_byte_out;
32
33 using torch::executor::testing::TensorFactory;
34
35 /// A generic smoke test that works for any dtype that supports ones() and
36 /// zeros().
37 template <exec_aten::ScalarType DTYPE>
test_dtype()38 void test_dtype() {
39 TensorFactory<ScalarType::Float> tf;
40 TensorFactory<ScalarType::Long> tf_l;
41
42 float scale = 0.5;
43 float zero_point = 1;
44 int64_t quant_min = 0;
45 int64_t quant_max = 255;
46
47 // clang-format off
48 Tensor weight = tf.make({3, 2}, {3.5, 2.0,
49 4, 1,
50 5.5, 13.2});
51 // clang-format on
52 // TODO make these different per dimension once per channel quant ops
53 // available
54 Tensor weight_scales = tf.full({3}, scale);
55 Tensor weight_zero_points = tf.full({3}, zero_point);
56
57 Tensor indices = tf_l.make({2}, {0, 2});
58
59 Tensor out = tf.zeros({2, 2});
60
61 TensorFactory<DTYPE> tfo;
62 Tensor qweight = tfo.zeros({3, 2});
63
64 // 3.5 / 0.5 + 1 = 8
65 // 2 / 0.5 + 1 = 5
66 // 4 / 0.5 + 1 = 9
67 // 1 / 0.5 + 1 = 3
68 // 5.5 / 0.5 + 1 = 12
69 // 13.2 / 0.5 + 1 = 27
70 quantize_per_tensor_out(
71 weight, scale, (float)zero_point, quant_min, quant_max, DTYPE, qweight);
72
73 quantized_embedding_byte_out(
74 qweight,
75 weight_scales,
76 weight_zero_points,
77 quant_min,
78 quant_max,
79 indices,
80 out);
81
82 // (8 - 1) * 0.5 = 3.5
83 // (5 - 1) * 0.5 = 2.0
84 // (12 - 1) * 0.5 = 5.5
85 // (27 - 1) * 0.5 = 13
86 // clang-format off
87 Tensor expected = tf.make({2, 2}, {3.5, 2,
88 5.5, 13});
89 // clang-format on
90
91 EXPECT_TENSOR_EQ(out, expected);
92 }
93
// Runs the shared smoke test for the quantized weight dtype covered by this
// suite. Only Byte is instantiated here.
TEST(OpQuantizedEmbeddingTest, AllDtypesSupported) {
  constexpr auto kWeightDtype = ScalarType::Byte;
  test_dtype<kWeightDtype>();
}
97
98 // Q -> DQ -> FP Embedding should be == to Q -> QEmbedding Bytes
TEST(OpQuantizedEmbeddingTest,ConsitencyWithReferencePattern)99 TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
100 TensorFactory<ScalarType::Float> tf;
101 TensorFactory<ScalarType::Int> tf_i;
102 TensorFactory<ScalarType::Long> tf_l;
103
104 float scale = 0.5;
105 float zero_point = 1;
106 int64_t quant_min = 0;
107 int64_t quant_max = 255;
108
109 // Do Q -> QEmbedding Bytes
110 Tensor weight = tf.make({3, 1}, {3.5, 5.5, 1.0});
111 // TODO make these different per dimension once per channel quant ops
112 // available
113 Tensor weight_scales = tf.full({3}, scale);
114 Tensor weight_zero_points = tf.full({3}, zero_point);
115
116 Tensor indices = tf_l.make({2}, {0, 2});
117
118 Tensor out = tf.zeros({2, 1});
119 Tensor fp_out = tf.zeros({2, 1});
120
121 TensorFactory<ScalarType::Byte> tfo;
122 Tensor qweight = tfo.zeros({3, 1});
123 KernelRuntimeContext context{};
124 // 3.5 / 0.5 + 1 = 8
125 // 5.5 / 0.5 + 1 = 12
126 // 1 / 0.5 + 1 = 3
127 quantize_per_tensor_out(
128 weight,
129 scale,
130 (int64_t)zero_point,
131 quant_min,
132 quant_max,
133 ScalarType::Byte,
134 qweight);
135
136 quantized_embedding_byte_out(
137 qweight,
138 weight_scales,
139 weight_zero_points,
140 quant_min,
141 quant_max,
142 indices,
143 out);
144
145 // Do Q DQ embedding
146 dequantize_per_tensor_out(
147 qweight,
148 scale,
149 (int64_t)zero_point,
150 quant_min,
151 quant_max,
152 ScalarType::Byte,
153 optional<ScalarType>(),
154 weight);
155
156 embedding_out(
157 context,
158 weight,
159 indices,
160 /*padding_idx=*/0,
161 /*scale_grad_by_freq=*/false,
162 /*sparse=*/false,
163 fp_out);
164
165 // can lossessly dq here so retrive the full information
166 // (8 - 1) * 0.5 = 3.5
167 // (3 - 1) * 0.5 = 1
168 Tensor expected = tf.make({2, 1}, {3.5, 1});
169 EXPECT_TENSOR_EQ(out, fp_out);
170 EXPECT_TENSOR_EQ(out, expected);
171 }
172
// Exercises quantized_embedding_byte_out on a pre-quantized 3x4 Byte weight
// with two layouts of quantization parameters:
//   1. one scale / zero point per row (shape {3}), and
//   2. two scales / zero points per row (shape {3, 2}), i.e. groups of 2
//      columns.
// Rows are gathered in the order {0, 2, 1}.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
  Tensor indices = tf_l.make({3}, {0, 2, 1});

  // Stage 1: per-row quantization parameters.
  /*
  fp_weight = [3.5, 4.5, 5.5, 6.5,
               5.0, 7.0, 7.0, 9.0,
               1.5, 3.0, 4.5, 7.5]
  */
  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});

  Tensor out = tf.zeros({3, 4});
  Tensor expected = tf.make(
      {3, 4}, {3.5, 4.5, 5.5, 6.5, 1.5, 3.0, 4.5, 7.5, 5.0, 7.0, 7.0, 9.0});

  quantized_embedding_byte_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // Stage 2: groupwise quantization. groupsize = 2
  /*
  fp_weight = [3.5, 4.5, 7, 9,
               4.5, 7.5, 6, 10,
               -7.5, -5.0, -9.0, -3.0]
  */
  Tensor group_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
  Tensor group_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});

  Tensor group_out = tf.zeros({3, 4});
  Tensor group_expected = tf.make(
      {3, 4}, {3.5, 4.5, 7, 9, -7.5, -5.0, -9.0, -3.0, 4.5, 7.5, 6, 10});

  quantized_embedding_byte_out(
      qweight,
      group_scales,
      group_zero_points,
      quant_min,
      quant_max,
      indices,
      group_out);

  EXPECT_TENSOR_EQ(group_out, group_expected);
}
229
// Expects the kernel to abort when the 1-D scales / zero points hold 4
// entries while the quantized weight has only 3 rows.
// NOTE(review): the exact check that fires lives in the kernel, not in this
// file; confirm there if the message ever needs to be matched.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  // Four parameters for a three-row table.
  Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});
  Tensor out = tf.zeros({3, 4});

  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
259
// Expects the kernel to abort when the 1-D scales / zero points hold only 2
// entries while the quantized weight has 3 rows.
// NOTE(review): the exact check that fires lives in the kernel, not in this
// file; confirm there if the message ever needs to be matched.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  // Two parameters for a three-row table.
  Tensor weight_scales = tf.make({2}, {0.5, 1.0});
  Tensor weight_zero_points = tf.make({2}, {1, 5});
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});
  Tensor out = tf.zeros({3, 4});

  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
289
// Expects the kernel to abort when the scales / zero points describe 2 groups
// per row (shape {3, 2}) but the quantized weight has 3 columns, which cannot
// be split evenly into 2 groups.
// NOTE(review): the exact check that fires lives in the kernel, not in this
// file; confirm there if the message ever needs to be matched.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
  // Three columns per row against two groups per row.
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});
  Tensor out = tf.zeros({3, 3});

  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
318
// Expects the kernel to abort when scales and zero points disagree in rank:
// scales are 2-D (shape {3, 2}) while zero points are 1-D (shape {3}). The
// 3-column weight also does not divide evenly into the 2 groups the scales
// describe.
// NOTE(review): the exact check that fires lives in the kernel, not in this
// file; confirm there if the message ever needs to be matched.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  // Rank differs from weight_scales.
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});
  Tensor out = tf.zeros({3, 3});

  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
347
// Expects the kernel to abort when scales and zero points have the same rank
// but different shapes: scales are {3, 2} while zero points are {3, 3}. The
// 3-column weight also does not divide evenly into the 2 groups the scales
// describe.
// NOTE(review): the exact check that fires lives in the kernel, not in this
// file; confirm there if the message ever needs to be matched.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;
  TensorFactory<ScalarType::Byte> tfo;

  const int64_t quant_min = 0;
  const int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  // Three zero points per row against two scales per row.
  Tensor weight_zero_points = tf.make({3, 3}, {1, 5, 7, 1, 5, 7, 1, 5, 7});
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});
  Tensor out = tf.zeros({3, 3});

  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
376