xref: /aosp_15_r20/external/executorch/kernels/quantized/test/test_quant_dequant_per_token.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import unittest
10
11import torch
12from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
13
14
15class QuantizePerTokenTest(unittest.TestCase):
16
17    def test_quantize_per_token(self):
18        input_tensor = torch.tensor(
19            [[-0.5, 0.3, 1.2], [0.1, -0.8, 2.1], [-5, 1, 2]], dtype=torch.float32
20        )
21        scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64)
22        scale = scale.unsqueeze(-1)
23        zero_point = torch.tensor([-1, -2, 0])
24        zero_point = zero_point.unsqueeze(-1)
25        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(
26            input_tensor, scale, zero_point, -128, 127, torch.int8
27        )
28        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(
29            input_tensor, scale, zero_point, -128, 127, torch.int8
30        )
31
32        self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor))
33
34    def test_quantize_per_token_large_tensor(self):
35        input_tensor = torch.rand((8, 32))
36        scale = torch.rand((8, 1), dtype=torch.float64)
37        zero_point = torch.randint(0, 10, (8, 1))
38        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(
39            input_tensor, scale, zero_point, -128, 127, torch.int8
40        )
41        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(
42            input_tensor, scale, zero_point, -128, 127, torch.int8
43        )
44
45        self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor))
46
47    def test_quantize_per_token_high_rank(self):
48        input_tensor = torch.rand((1, 3, 8, 32))
49        scale = torch.rand((1, 3, 8, 1), dtype=torch.float64)
50        zero_point = torch.randint(0, 10, (1, 3, 8, 1))
51        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(
52            input_tensor, scale, zero_point, -128, 127, torch.int8
53        )
54        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(
55            input_tensor, scale, zero_point, -128, 127, torch.int8
56        )
57
58        self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor))
59
60    def test_quantize_per_token_dynamic(self):
61        input_tensor = torch.rand((1, 1, 8, 1))
62        scale = torch.rand((1, 1, 8, 1), dtype=torch.float64)
63        zero_point = torch.randint(0, 10, (1, 1, 8, 1))
64        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(
65            input_tensor, scale, zero_point, -128, 127, torch.int8
66        )
67        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(
68            input_tensor, scale, zero_point, -128, 127, torch.int8
69        )
70
71        self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor))
72
73        input_tensor = torch.rand((1, 3, 8, 1))
74        scale = torch.rand((1, 3, 8, 1), dtype=torch.float64)
75        zero_point = torch.randint(0, 10, (1, 3, 8, 1))
76        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(
77            input_tensor, scale, zero_point, -128, 127, torch.int8
78        )
79        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(
80            input_tensor, scale, zero_point, -128, 127, torch.int8
81        )
82
83        self.assertTrue(torch.equal(quantized_tensor, expected_quantized_tensor))
84
85    def test_dequantize_per_token(self):
86        input_tensor = torch.randint(-50, 120, (3, 3), dtype=torch.int8)
87        scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64)
88        scale = scale.unsqueeze(-1)
89        zero_point = torch.tensor([-1, -2, 0])
90        zero_point = zero_point.unsqueeze(-1)
91        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(
92            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
93        )
94        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
95            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
96        )
97
98        self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
99
100    def test_dequantize_per_token_large_tensor(self):
101        input_tensor = torch.randint(-50, 120, (8, 32), dtype=torch.int8)
102        scale = torch.rand((8, 1), dtype=torch.float64)
103        zero_point = torch.randint(0, 10, (8, 1))
104        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(
105            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
106        )
107        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
108            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
109        )
110
111        self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
112
113    def test_dequantize_per_token_high_rank(self):
114        input_tensor = torch.randint(-50, 120, (1, 3, 8, 32), dtype=torch.int8)
115        scale = torch.rand((1, 3, 8, 1), dtype=torch.float64)
116        zero_point = torch.randint(0, 10, (1, 3, 8, 1))
117        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(
118            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
119        )
120        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
121            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
122        )
123
124        self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
125
126    def test_dequantize_per_token_dynamic(self):
127        input_tensor = torch.randint(-50, 120, (1, 1, 8, 32), dtype=torch.int8)
128        scale = torch.rand((1, 1, 8, 1), dtype=torch.float64)
129        zero_point = torch.randint(0, 10, (1, 1, 8, 1))
130        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(
131            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
132        )
133        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
134            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
135        )
136
137        self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
138
139        input_tensor = torch.randint(-50, 120, (1, 3, 8, 32), dtype=torch.int8)
140        scale = torch.rand((1, 3, 8, 1), dtype=torch.float64)
141        zero_point = torch.randint(0, 10, (1, 3, 8, 1))
142        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(
143            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
144        )
145        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
146            input_tensor, scale, zero_point, -128, 127, torch.int8, torch.float32
147        )
148
149        self.assertTrue(torch.allclose(dequantized_tensor, expected_dequantized_tensor))
150