xref: /aosp_15_r20/external/executorch/kernels/quantized/test/op_embedding_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/NativeFunctions.h> // Declares the aten operator
10 #include <executorch/kernels/quantized/NativeFunctions.h> // Declares the quantized operator
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/platform/runtime.h>
16 #include <executorch/test/utils/DeathTest.h>
17 
18 #include <gtest/gtest.h>
19 #include <limits>
20 
21 using namespace ::testing;
22 using exec_aten::ArrayRef;
23 using exec_aten::optional;
24 using exec_aten::Scalar;
25 using exec_aten::ScalarType;
26 using exec_aten::Tensor;
27 using executorch::runtime::KernelRuntimeContext;
28 using torch::executor::native::dequantize_per_tensor_out;
29 using torch::executor::native::embedding_out;
30 using torch::executor::native::quantize_per_tensor_out;
31 using torch::executor::native::quantized_embedding_byte_out;
32 
33 using torch::executor::testing::TensorFactory;
34 
35 /// A generic smoke test that works for any dtype that supports ones() and
36 /// zeros().
37 template <exec_aten::ScalarType DTYPE>
test_dtype()38 void test_dtype() {
39   TensorFactory<ScalarType::Float> tf;
40   TensorFactory<ScalarType::Long> tf_l;
41 
42   float scale = 0.5;
43   float zero_point = 1;
44   int64_t quant_min = 0;
45   int64_t quant_max = 255;
46 
47   // clang-format off
48   Tensor weight = tf.make({3, 2}, {3.5, 2.0,
49                                    4, 1,
50                                    5.5, 13.2});
51   // clang-format on
52   // TODO make these different per dimension once per channel quant ops
53   // available
54   Tensor weight_scales = tf.full({3}, scale);
55   Tensor weight_zero_points = tf.full({3}, zero_point);
56 
57   Tensor indices = tf_l.make({2}, {0, 2});
58 
59   Tensor out = tf.zeros({2, 2});
60 
61   TensorFactory<DTYPE> tfo;
62   Tensor qweight = tfo.zeros({3, 2});
63 
64   // 3.5 / 0.5 + 1 = 8
65   // 2 / 0.5 + 1 = 5
66   // 4 / 0.5 + 1 = 9
67   // 1 / 0.5 + 1 = 3
68   // 5.5 / 0.5 + 1 = 12
69   // 13.2 / 0.5 + 1 = 27
70   quantize_per_tensor_out(
71       weight, scale, (float)zero_point, quant_min, quant_max, DTYPE, qweight);
72 
73   quantized_embedding_byte_out(
74       qweight,
75       weight_scales,
76       weight_zero_points,
77       quant_min,
78       quant_max,
79       indices,
80       out);
81 
82   // (8 - 1) * 0.5 = 3.5
83   // (5 - 1) * 0.5 = 2.0
84   // (12 - 1) * 0.5 = 5.5
85   // (27 - 1) * 0.5 = 13
86   // clang-format off
87   Tensor expected = tf.make({2, 2}, {3.5, 2,
88                                       5.5, 13});
89   // clang-format on
90 
91   EXPECT_TENSOR_EQ(out, expected);
92 }
93 
TEST(OpQuantizedEmbeddingTest, AllDtypesSupported) {
  // Byte is currently the only weight dtype exercised by this smoke test.
  constexpr ScalarType kWeightDtype = ScalarType::Byte;
  test_dtype<kWeightDtype>();
}
97 
98 // Q -> DQ -> FP Embedding should be == to Q -> QEmbedding Bytes
TEST(OpQuantizedEmbeddingTest,ConsitencyWithReferencePattern)99 TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
100   TensorFactory<ScalarType::Float> tf;
101   TensorFactory<ScalarType::Int> tf_i;
102   TensorFactory<ScalarType::Long> tf_l;
103 
104   float scale = 0.5;
105   float zero_point = 1;
106   int64_t quant_min = 0;
107   int64_t quant_max = 255;
108 
109   // Do Q -> QEmbedding Bytes
110   Tensor weight = tf.make({3, 1}, {3.5, 5.5, 1.0});
111   // TODO make these different per dimension once per channel quant ops
112   // available
113   Tensor weight_scales = tf.full({3}, scale);
114   Tensor weight_zero_points = tf.full({3}, zero_point);
115 
116   Tensor indices = tf_l.make({2}, {0, 2});
117 
118   Tensor out = tf.zeros({2, 1});
119   Tensor fp_out = tf.zeros({2, 1});
120 
121   TensorFactory<ScalarType::Byte> tfo;
122   Tensor qweight = tfo.zeros({3, 1});
123   KernelRuntimeContext context{};
124   // 3.5 / 0.5 + 1 = 8
125   // 5.5 / 0.5 + 1 = 12
126   // 1 / 0.5 + 1 = 3
127   quantize_per_tensor_out(
128       weight,
129       scale,
130       (int64_t)zero_point,
131       quant_min,
132       quant_max,
133       ScalarType::Byte,
134       qweight);
135 
136   quantized_embedding_byte_out(
137       qweight,
138       weight_scales,
139       weight_zero_points,
140       quant_min,
141       quant_max,
142       indices,
143       out);
144 
145   // Do Q DQ embedding
146   dequantize_per_tensor_out(
147       qweight,
148       scale,
149       (int64_t)zero_point,
150       quant_min,
151       quant_max,
152       ScalarType::Byte,
153       optional<ScalarType>(),
154       weight);
155 
156   embedding_out(
157       context,
158       weight,
159       indices,
160       /*padding_idx=*/0,
161       /*scale_grad_by_freq=*/false,
162       /*sparse=*/false,
163       fp_out);
164 
165   // can lossessly dq here so retrive the full information
166   // (8 - 1) * 0.5 = 3.5
167   // (3 - 1) * 0.5 = 1
168   Tensor expected = tf.make({2, 1}, {3.5, 1});
169   EXPECT_TENSOR_EQ(out, fp_out);
170   EXPECT_TENSOR_EQ(out, expected);
171 }
172 
// Exercises quantized_embedding_byte_out with per-row (1-D) and group-wise
// (2-D) scale / zero-point tensors over the same quantized weight.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // Per-row quantization: one scale / zero point per row of qweight.
  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  /*
  fp_weight = [3.5, 4.5, 5.5, 6.5,   // (q - 1) * 0.5
               5.0, 7.0, 7.0, 9.0,   // (q - 5) * 1.0
               1.5, 3.0, 4.5, 7.5]   // (q - 7) * 1.5
  Output rows follow indices {0, 2, 1}.
  */
  Tensor out = tf.zeros({3, 4});
  Tensor expected = tf.make(
      {3, 4}, {3.5, 4.5, 5.5, 6.5, 1.5, 3.0, 4.5, 7.5, 5.0, 7.0, 7.0, 9.0});

  quantized_embedding_byte_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // Groupwise quantization. groupsize = 2 (4 columns / 2 groups per row).
  weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
  weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
  /*
  fp_weight = [3.5, 4.5, 7, 9,
               4.5, 7.5, 6, 10,
               -7.5, -5.0, -9.0, -3.0]
  */

  out = tf.zeros({3, 4});
  expected = tf.make(
      {3, 4}, {3.5, 4.5, 7, 9, -7.5, -5.0, -9.0, -3.0, 4.5, 7.5, 6, 10});

  quantized_embedding_byte_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}
229 
// Invalid input: weight_scales / weight_zero_points have 4 entries, but
// qweight has only 3 rows, so the kernel is expected to abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  // The empty regex accepts any abort message.
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
259 
// Invalid input: weight_scales / weight_zero_points have only 2 entries, but
// qweight has 3 rows, so the kernel is expected to abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  Tensor weight_scales = tf.make({2}, {0.5, 1.0});
  Tensor weight_zero_points = tf.make({2}, {1, 5});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  // The empty regex accepts any abort message.
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
289 
// Invalid input: qweight has 3 columns, but the 2-D scale / zero-point
// tensors describe 2 groups per row. 3 is not divisible by 2, so no integer
// group size exists and the kernel is expected to abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  // The empty regex accepts any abort message.
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
318 
// Invalid input: weight_scales is 2-D {3, 2}, while weight_zero_points is
// 1-D {3}, so the two quantization-parameter tensors disagree in shape.
// NOTE(review): qweight's 3 columns are also not divisible by the 2 groups
// in weight_scales. Either problem could trigger the abort, and the empty
// regex does not distinguish between them.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
347 
// Invalid input: weight_scales is {3, 2} but weight_zero_points is {3, 3},
// so the two quantization-parameter tensors disagree in shape.
// NOTE(review): qweight's 3 columns are also not divisible by the 2 groups
// in weight_scales. Either problem could trigger the abort, and the empty
// regex does not distinguish between them.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3, 3}, {1, 5, 7, 1, 5, 7, 1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}
376