# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


class QuantizePerTokenTest(unittest.TestCase):
    """Cross-check per-token quantize/dequantize between two op namespaces.

    Every test feeds identical arguments to ``torch.ops.quantized_decomposed``
    (made available by the side-effect import above, hence the ``noqa``) and to
    ``torch.ops.et_quant_test``, then asserts the two results agree.

    NOTE(review): the ``et_quant_test`` namespace is not registered anywhere in
    this file; presumably the test target preloads the library that defines
    it — confirm before running this module standalone.
    """

    # int8 quantization over the full representable range.
    _QUANT_MIN = -128
    _QUANT_MAX = 127

    def setUp(self):
        # Most tests draw inputs, scales and zero points from torch.rand /
        # torch.randint. Seed the global RNG so that a failure reproduces on
        # rerun instead of appearing intermittently.
        torch.manual_seed(0)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _random_per_token_params(shape):
        """Return random ``(scale, zero_point)`` tensors for an input of ``shape``.

        Both have shape ``shape[:-1] + (1,)``: one value per row along the last
        dim. ``scale`` is float64 in [0, 1); ``zero_point`` is an integer in
        [0, 10).
        """
        param_shape = (*shape[:-1], 1)
        scale = torch.rand(param_shape, dtype=torch.float64)
        zero_point = torch.randint(0, 10, param_shape)
        return scale, zero_point

    def _check_quantize(self, input_tensor, scale, zero_point):
        """Assert both ``quantize_per_token`` ops return identical int8 tensors."""
        args = (
            input_tensor,
            scale,
            zero_point,
            self._QUANT_MIN,
            self._QUANT_MAX,
            torch.int8,
        )
        quantized_tensor = torch.ops.quantized_decomposed.quantize_per_token(*args)
        expected_quantized_tensor = torch.ops.et_quant_test.quantize_per_token(*args)
        self.assertTrue(
            torch.equal(quantized_tensor, expected_quantized_tensor),
            msg=(
                "quantize_per_token mismatch:\n"
                f"quantized_decomposed:\n{quantized_tensor}\n"
                f"et_quant_test:\n{expected_quantized_tensor}"
            ),
        )

    def _check_dequantize(self, input_tensor, scale, zero_point):
        """Assert both ``dequantize_per_token`` ops agree under ``torch.allclose``."""
        args = (
            input_tensor,
            scale,
            zero_point,
            self._QUANT_MIN,
            self._QUANT_MAX,
            torch.int8,
            torch.float32,
        )
        dequantized_tensor = torch.ops.quantized_decomposed.dequantize_per_token(*args)
        expected_dequantized_tensor = torch.ops.et_quant_test.dequantize_per_token(
            *args
        )
        self.assertTrue(
            torch.allclose(dequantized_tensor, expected_dequantized_tensor),
            msg=(
                "dequantize_per_token mismatch:\n"
                f"quantized_decomposed:\n{dequantized_tensor}\n"
                f"et_quant_test:\n{expected_dequantized_tensor}"
            ),
        )

    def _check_random_quantize(self, shape):
        """Run ``_check_quantize`` on a random float32 input of ``shape``."""
        input_tensor = torch.rand(shape)
        self._check_quantize(input_tensor, *self._random_per_token_params(shape))

    def _check_random_dequantize(self, shape):
        """Run ``_check_dequantize`` on a random int8 input of ``shape``."""
        input_tensor = torch.randint(-50, 120, shape, dtype=torch.int8)
        self._check_dequantize(input_tensor, *self._random_per_token_params(shape))

    # ------------------------------------------------------------- quantize

    def test_quantize_per_token(self):
        input_tensor = torch.tensor(
            [[-0.5, 0.3, 1.2], [0.1, -0.8, 2.1], [-5, 1, 2]], dtype=torch.float32
        )
        scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64).unsqueeze(-1)
        zero_point = torch.tensor([-1, -2, 0]).unsqueeze(-1)
        self._check_quantize(input_tensor, scale, zero_point)

    def test_quantize_per_token_large_tensor(self):
        self._check_random_quantize((8, 32))

    def test_quantize_per_token_high_rank(self):
        self._check_random_quantize((1, 3, 8, 32))

    def test_quantize_per_token_dynamic(self):
        # The same test calls the ops with two different leading shapes in turn.
        for shape in ((1, 1, 8, 1), (1, 3, 8, 1)):
            with self.subTest(shape=shape):
                self._check_random_quantize(shape)

    # ----------------------------------------------------------- dequantize

    def test_dequantize_per_token(self):
        input_tensor = torch.randint(-50, 120, (3, 3), dtype=torch.int8)
        scale = torch.tensor([0.5, 0.8, 1.0], dtype=torch.float64).unsqueeze(-1)
        zero_point = torch.tensor([-1, -2, 0]).unsqueeze(-1)
        self._check_dequantize(input_tensor, scale, zero_point)

    def test_dequantize_per_token_large_tensor(self):
        self._check_random_dequantize((8, 32))

    def test_dequantize_per_token_high_rank(self):
        self._check_random_dequantize((1, 3, 8, 32))

    def test_dequantize_per_token_dynamic(self):
        # The same test calls the ops with two different leading shapes in turn.
        for shape in ((1, 1, 8, 32), (1, 3, 8, 32)):
            with self.subTest(shape=shape):
                self._check_random_dequantize(shape)