# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver


class TestUtils(TestCase):
    def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        for fqn, expected_dims in expected_fqn_to_dim.items():
            # the returned map must contain an entry for every expected fqn
            assert fqn in fqn_to_example_inputs
            example_inputs = fqn_to_example_inputs[fqn]
            for example_input, expected_dim in zip(example_inputs, expected_dims):
                assert example_input.dim() == expected_dim

    def test_get_fqn_to_example_inputs_simple(self):
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x)
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            "sub": (2,),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)

    def test_get_fqn_to_example_inputs_default_kwargs(self):
        """ Test that we can get example inputs for functions with default keyword arguments
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                # only override `key2`; `key1` will use its default value
                x = self.sub(x, key2=torch.rand(1, 2))
                return x

        expected_fqn_to_dim = {
            "": (2,),
            "linear1": (2,),
            "linear2": (2,),
            # second arg is `key1`, which uses its default value
            # third arg is `key2`, overridden at the call site
            "sub": (2, 1, 2),
            "sub.linear1": (2,),
            "sub.linear2": (2,)
        }
        example_inputs = (torch.rand(1, 5),)
        self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)

    def test_get_fqn_to_example_inputs_complex_args(self):
        """ Test that we can record complex example inputs such as lists and dicts
        """
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)

            def forward(self, x, list_arg, dict_arg):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(5, 5)
                self.linear2 = torch.nn.Linear(5, 5)
                self.sub = Sub()

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.sub(x, [x], {"3": x})
                return x

        example_inputs = (torch.rand(1, 5),)
        m = M().eval()
        fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
        assert "sub" in fqn_to_example_inputs
        assert isinstance(fqn_to_example_inputs["sub"][1], list)
        assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
            "3" in fqn_to_example_inputs["sub"][2]

    def test_quantize_weight_clamping_per_tensor(self):
        """ Test that quant_{min, max} from the per tensor observer are honored by the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([fp_min, fp_max])

        observer = MovingAverageMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_tensor_symmetric,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val] range for the moving average observer
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

    def test_quantize_weight_clamping_per_channel(self):
        """ Test that quant_{min, max} from the per channel observer are honored by the `_quantize_weight` method
        """
        fp_min, fp_max = -1000.0, 1000.0
        q8_min, q8_max = -10, 10

        float_tensor = torch.tensor([[fp_min, fp_max]])

        observer = MovingAveragePerChannelMinMaxObserver(
            averaging_constant=1.0,
            dtype=torch.qint8,
            quant_min=q8_min,
            quant_max=q8_max,
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        )

        observer(float_tensor)
        assert observer.min_val == fp_min
        assert observer.max_val == fp_max

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

        # Actual weight values can fall outside the observer's [min_val, max_val] range for the moving average observer
        float_tensor *= 1.2

        quantized_tensor = _quantize_weight(float_tensor, observer)
        assert quantized_tensor.int_repr().max().item() == q8_max
        assert quantized_tensor.int_repr().min().item() == q8_min

    def test_uint1_7_dtype(self):

        def up_size(size):
            # each uint8 byte packs two uint4 values, so the last dim doubles
            return (*size[:-1], size[-1] * 2)

        class UInt4Tensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, **kwargs):
                assert elem.dtype is torch.uint8
                assert not kwargs.get("requires_grad", False)
                kwargs["requires_grad"] = False
                return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)

            def __init__(self, elem):
                self.elem = elem

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # stub: this test only checks construction and dtype, so no ops need to be implemented
                pass

        # make sure it runs
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        assert x.dtype == torch.uint4
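
# The guard below is an addition, not part of the original snippet: it follows the
# usual PyTorch convention (an assumption here) that quantization tests are launched
# through test/test_quantization.py rather than executed as a standalone script.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )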