1# Owner(s): ["oncall: quantization"] 2 3import os 4import sys 5import unittest 6from typing import Set 7 8# torch 9import torch 10import torch.ao.nn.intrinsic.quantized as nniq 11import torch.ao.nn.quantized as nnq 12import torch.ao.nn.quantized.dynamic as nnqd 13import torch.ao.quantization.quantize_fx as quantize_fx 14import torch.nn as nn 15from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver 16from torch.fx import GraphModule 17from torch.testing._internal.common_quantization import skipIfNoFBGEMM 18from torch.testing._internal.common_quantized import ( 19 override_qengines, 20 qengine_is_fbgemm, 21) 22 23# Testing utils 24from torch.testing._internal.common_utils import IS_AVX512_VNNI_SUPPORTED, TestCase 25from torch.testing._internal.quantization_torch_package_models import ( 26 LinearReluFunctional, 27) 28 29 30def remove_prefix(text, prefix): 31 if text.startswith(prefix): 32 return text[len(prefix) :] 33 return text 34 35 36def get_filenames(self, subname): 37 # NB: we take __file__ from the module that defined the test 38 # class, so we place the expect directory where the test script 39 # lives, NOT where test/common_utils.py lives. 40 module_id = self.__class__.__module__ 41 munged_id = remove_prefix(self.id(), module_id + ".") 42 test_file = os.path.realpath(sys.modules[module_id].__file__) 43 base_name = os.path.join(os.path.dirname(test_file), "../serialized", munged_id) 44 45 subname_output = "" 46 if subname: 47 base_name += "_" + subname 48 subname_output = f" ({subname})" 49 50 input_file = base_name + ".input.pt" 51 state_dict_file = base_name + ".state_dict.pt" 52 scripted_module_file = base_name + ".scripted.pt" 53 traced_module_file = base_name + ".traced.pt" 54 expected_file = base_name + ".expected.pt" 55 package_file = base_name + ".package.pt" 56 get_attr_targets_file = base_name + ".get_attr_targets.pt" 57 58 return ( 59 input_file, 60 state_dict_file, 61 scripted_module_file, 62 traced_module_file, 63 expected_file, 64 package_file, 65 get_attr_targets_file, 66 ) 67 68 69class TestSerialization(TestCase): 70 """Test backward compatiblity for serialization and numerics""" 71 72 # Copy and modified from TestCase.assertExpected 73 def _test_op( 74 self, 75 qmodule, 76 subname=None, 77 input_size=None, 78 input_quantized=True, 79 generate=False, 80 prec=None, 81 new_zipfile_serialization=False, 82 ): 83 r"""Test quantized modules serialized previously can be loaded 84 with current code, make sure we don't break backward compatibility for the 85 serialization of quantized modules 86 """ 87 ( 88 input_file, 89 state_dict_file, 90 scripted_module_file, 91 traced_module_file, 92 expected_file, 93 _package_file, 94 _get_attr_targets_file, 95 ) = get_filenames(self, subname) 96 97 # only generate once. 98 if generate and qengine_is_fbgemm(): 99 input_tensor = torch.rand(*input_size).float() 100 if input_quantized: 101 input_tensor = torch.quantize_per_tensor( 102 input_tensor, 0.5, 2, torch.quint8 103 ) 104 torch.save(input_tensor, input_file) 105 # Temporary fix to use _use_new_zipfile_serialization until #38379 lands. 106 torch.save( 107 qmodule.state_dict(), 108 state_dict_file, 109 _use_new_zipfile_serialization=new_zipfile_serialization, 110 ) 111 torch.jit.save(torch.jit.script(qmodule), scripted_module_file) 112 torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file) 113 torch.save(qmodule(input_tensor), expected_file) 114 115 input_tensor = torch.load(input_file) 116 # weights_only = False as sometimes get ScriptObject here 117 qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False)) 118 qmodule_scripted = torch.jit.load(scripted_module_file) 119 qmodule_traced = torch.jit.load(traced_module_file) 120 expected = torch.load(expected_file) 121 self.assertEqual(qmodule(input_tensor), expected, atol=prec) 122 self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) 123 self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) 124 125 def _test_op_graph( 126 self, 127 qmodule, 128 subname=None, 129 input_size=None, 130 input_quantized=True, 131 generate=False, 132 prec=None, 133 new_zipfile_serialization=False, 134 ): 135 r""" 136 Input: a floating point module 137 138 If generate == True, traces and scripts the module and quantizes the results with 139 PTQ, and saves the results. 140 141 If generate == False, traces and scripts the module and quantizes the results with 142 PTQ, and compares to saved results. 143 """ 144 ( 145 input_file, 146 state_dict_file, 147 scripted_module_file, 148 traced_module_file, 149 expected_file, 150 _package_file, 151 _get_attr_targets_file, 152 ) = get_filenames(self, subname) 153 154 # only generate once. 155 if generate and qengine_is_fbgemm(): 156 input_tensor = torch.rand(*input_size).float() 157 torch.save(input_tensor, input_file) 158 159 # convert to TorchScript 160 scripted = torch.jit.script(qmodule) 161 traced = torch.jit.trace(qmodule, input_tensor) 162 163 # quantize 164 165 def _eval_fn(model, data): 166 model(data) 167 168 qconfig_dict = {"": torch.ao.quantization.default_qconfig} 169 scripted_q = torch.ao.quantization.quantize_jit( 170 scripted, qconfig_dict, _eval_fn, [input_tensor] 171 ) 172 traced_q = torch.ao.quantization.quantize_jit( 173 traced, qconfig_dict, _eval_fn, [input_tensor] 174 ) 175 176 torch.jit.save(scripted_q, scripted_module_file) 177 torch.jit.save(traced_q, traced_module_file) 178 torch.save(scripted_q(input_tensor), expected_file) 179 180 input_tensor = torch.load(input_file) 181 qmodule_scripted = torch.jit.load(scripted_module_file) 182 qmodule_traced = torch.jit.load(traced_module_file) 183 expected = torch.load(expected_file) 184 self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) 185 self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) 186 187 def _test_obs( 188 self, obs, input_size, subname=None, generate=False, check_numerics=True 189 ): 190 """ 191 Test observer code can be loaded from state_dict. 192 """ 193 ( 194 input_file, 195 state_dict_file, 196 _, 197 traced_module_file, 198 expected_file, 199 _package_file, 200 _get_attr_targets_file, 201 ) = get_filenames(self, None) 202 if generate: 203 input_tensor = torch.rand(*input_size).float() 204 torch.save(input_tensor, input_file) 205 torch.save(obs(input_tensor), expected_file) 206 torch.save(obs.state_dict(), state_dict_file) 207 208 input_tensor = torch.load(input_file) 209 obs.load_state_dict(torch.load(state_dict_file)) 210 expected = torch.load(expected_file) 211 if check_numerics: 212 self.assertEqual(obs(input_tensor), expected) 213 214 def _test_package(self, fp32_module, input_size, generate=False): 215 """ 216 Verifies that files created in the past with torch.package 217 work on today's FX graph mode quantization transforms. 218 """ 219 ( 220 input_file, 221 state_dict_file, 222 _scripted_module_file, 223 _traced_module_file, 224 expected_file, 225 package_file, 226 get_attr_targets_file, 227 ) = get_filenames(self, None) 228 229 package_name = "test" 230 resource_name_model = "test.pkl" 231 232 def _do_quant_transforms( 233 m: torch.nn.Module, 234 input_tensor: torch.Tensor, 235 ) -> torch.nn.Module: 236 example_inputs = (input_tensor,) 237 # do the quantizaton transforms and save result 238 qconfig = torch.ao.quantization.get_default_qconfig("fbgemm") 239 mp = quantize_fx.prepare_fx(m, {"": qconfig}, example_inputs=example_inputs) 240 mp(input_tensor) 241 mq = quantize_fx.convert_fx(mp) 242 return mq 243 244 def _get_get_attr_target_strings(m: GraphModule) -> Set[str]: 245 results = set() 246 for node in m.graph.nodes: 247 if node.op == "get_attr": 248 results.add(node.target) 249 return results 250 251 if generate and qengine_is_fbgemm(): 252 input_tensor = torch.randn(*input_size) 253 torch.save(input_tensor, input_file) 254 255 # save the model with torch.package 256 with torch.package.PackageExporter(package_file) as exp: 257 exp.intern("torch.testing._internal.quantization_torch_package_models") 258 exp.save_pickle(package_name, resource_name_model, fp32_module) 259 260 # do the quantization transforms and save the result 261 mq = _do_quant_transforms(fp32_module, input_tensor) 262 get_attrs = _get_get_attr_target_strings(mq) 263 torch.save(get_attrs, get_attr_targets_file) 264 q_result = mq(input_tensor) 265 torch.save(q_result, expected_file) 266 267 # load input tensor 268 input_tensor = torch.load(input_file) 269 expected_output_tensor = torch.load(expected_file) 270 expected_get_attrs = torch.load(get_attr_targets_file, weights_only=False) 271 272 # load model from package and verify output and get_attr targets match 273 imp = torch.package.PackageImporter(package_file) 274 m = imp.load_pickle(package_name, resource_name_model) 275 mq = _do_quant_transforms(m, input_tensor) 276 277 get_attrs = _get_get_attr_target_strings(mq) 278 self.assertTrue( 279 get_attrs == expected_get_attrs, 280 f"get_attrs: expected {expected_get_attrs}, got {get_attrs}", 281 ) 282 output_tensor = mq(input_tensor) 283 self.assertTrue(torch.allclose(output_tensor, expected_output_tensor)) 284 285 @override_qengines 286 def test_linear(self): 287 module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8) 288 self._test_op(module, input_size=[1, 3], generate=False) 289 290 @override_qengines 291 def test_linear_relu(self): 292 module = nniq.LinearReLU(3, 1, bias=True, dtype=torch.qint8) 293 self._test_op(module, input_size=[1, 3], generate=False) 294 295 @override_qengines 296 def test_linear_dynamic(self): 297 module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8) 298 self._test_op( 299 module_qint8, 300 "qint8", 301 input_size=[1, 3], 302 input_quantized=False, 303 generate=False, 304 ) 305 if qengine_is_fbgemm(): 306 module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16) 307 self._test_op( 308 module_float16, 309 "float16", 310 input_size=[1, 3], 311 input_quantized=False, 312 generate=False, 313 ) 314 315 @override_qengines 316 def test_conv2d(self): 317 module = nnq.Conv2d( 318 3, 319 3, 320 kernel_size=3, 321 stride=1, 322 padding=0, 323 dilation=1, 324 groups=1, 325 bias=True, 326 padding_mode="zeros", 327 ) 328 self._test_op(module, input_size=[1, 3, 6, 6], generate=False) 329 330 @override_qengines 331 def test_conv2d_nobias(self): 332 module = nnq.Conv2d( 333 3, 334 3, 335 kernel_size=3, 336 stride=1, 337 padding=0, 338 dilation=1, 339 groups=1, 340 bias=False, 341 padding_mode="zeros", 342 ) 343 self._test_op(module, input_size=[1, 3, 6, 6], generate=False) 344 345 @override_qengines 346 def test_conv2d_graph(self): 347 module = nn.Sequential( 348 torch.ao.quantization.QuantStub(), 349 nn.Conv2d( 350 3, 351 3, 352 kernel_size=3, 353 stride=1, 354 padding=0, 355 dilation=1, 356 groups=1, 357 bias=True, 358 padding_mode="zeros", 359 ), 360 ) 361 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 362 363 @override_qengines 364 def test_conv2d_nobias_graph(self): 365 module = nn.Sequential( 366 torch.ao.quantization.QuantStub(), 367 nn.Conv2d( 368 3, 369 3, 370 kernel_size=3, 371 stride=1, 372 padding=0, 373 dilation=1, 374 groups=1, 375 bias=False, 376 padding_mode="zeros", 377 ), 378 ) 379 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 380 381 @override_qengines 382 def test_conv2d_graph_v2(self): 383 # tests the same thing as test_conv2d_graph, but for version 2 of 384 # ConvPackedParams{n}d 385 module = nn.Sequential( 386 torch.ao.quantization.QuantStub(), 387 nn.Conv2d( 388 3, 389 3, 390 kernel_size=3, 391 stride=1, 392 padding=0, 393 dilation=1, 394 groups=1, 395 bias=True, 396 padding_mode="zeros", 397 ), 398 ) 399 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 400 401 @override_qengines 402 def test_conv2d_nobias_graph_v2(self): 403 # tests the same thing as test_conv2d_nobias_graph, but for version 2 of 404 # ConvPackedParams{n}d 405 module = nn.Sequential( 406 torch.ao.quantization.QuantStub(), 407 nn.Conv2d( 408 3, 409 3, 410 kernel_size=3, 411 stride=1, 412 padding=0, 413 dilation=1, 414 groups=1, 415 bias=False, 416 padding_mode="zeros", 417 ), 418 ) 419 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 420 421 @override_qengines 422 def test_conv2d_graph_v3(self): 423 # tests the same thing as test_conv2d_graph, but for version 3 of 424 # ConvPackedParams{n}d 425 module = nn.Sequential( 426 torch.ao.quantization.QuantStub(), 427 nn.Conv2d( 428 3, 429 3, 430 kernel_size=3, 431 stride=1, 432 padding=0, 433 dilation=1, 434 groups=1, 435 bias=True, 436 padding_mode="zeros", 437 ), 438 ) 439 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 440 441 @override_qengines 442 def test_conv2d_nobias_graph_v3(self): 443 # tests the same thing as test_conv2d_nobias_graph, but for version 3 of 444 # ConvPackedParams{n}d 445 module = nn.Sequential( 446 torch.ao.quantization.QuantStub(), 447 nn.Conv2d( 448 3, 449 3, 450 kernel_size=3, 451 stride=1, 452 padding=0, 453 dilation=1, 454 groups=1, 455 bias=False, 456 padding_mode="zeros", 457 ), 458 ) 459 self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) 460 461 @override_qengines 462 def test_conv2d_relu(self): 463 module = nniq.ConvReLU2d( 464 3, 465 3, 466 kernel_size=3, 467 stride=1, 468 padding=0, 469 dilation=1, 470 groups=1, 471 bias=True, 472 padding_mode="zeros", 473 ) 474 self._test_op(module, input_size=[1, 3, 6, 6], generate=False) 475 # TODO: graph mode quantized conv2d module 476 477 @override_qengines 478 def test_conv3d(self): 479 if qengine_is_fbgemm(): 480 module = nnq.Conv3d( 481 3, 482 3, 483 kernel_size=3, 484 stride=1, 485 padding=0, 486 dilation=1, 487 groups=1, 488 bias=True, 489 padding_mode="zeros", 490 ) 491 self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False) 492 # TODO: graph mode quantized conv3d module 493 494 @override_qengines 495 def test_conv3d_relu(self): 496 if qengine_is_fbgemm(): 497 module = nniq.ConvReLU3d( 498 3, 499 3, 500 kernel_size=3, 501 stride=1, 502 padding=0, 503 dilation=1, 504 groups=1, 505 bias=True, 506 padding_mode="zeros", 507 ) 508 self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False) 509 # TODO: graph mode quantized conv3d module 510 511 @override_qengines 512 @unittest.skipIf( 513 IS_AVX512_VNNI_SUPPORTED, 514 "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098", 515 ) 516 def test_lstm(self): 517 class LSTMModule(torch.nn.Module): 518 def __init__(self) -> None: 519 super().__init__() 520 self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to( 521 dtype=torch.float 522 ) 523 524 def forward(self, x): 525 x = self.lstm(x) 526 return x 527 528 if qengine_is_fbgemm(): 529 mod = LSTMModule() 530 self._test_op( 531 mod, 532 input_size=[4, 4, 3], 533 input_quantized=False, 534 generate=False, 535 new_zipfile_serialization=True, 536 ) 537 538 def test_per_channel_observer(self): 539 obs = PerChannelMinMaxObserver() 540 self._test_obs(obs, input_size=[5, 5], generate=False) 541 542 def test_per_tensor_observer(self): 543 obs = MinMaxObserver() 544 self._test_obs(obs, input_size=[5, 5], generate=False) 545 546 def test_default_qat_qconfig(self): 547 class Model(nn.Module): 548 def __init__(self) -> None: 549 super().__init__() 550 self.linear = nn.Linear(5, 5) 551 self.relu = nn.ReLU() 552 553 def forward(self, x): 554 x = self.linear(x) 555 x = self.relu(x) 556 return x 557 558 model = Model() 559 model.linear.weight = torch.nn.Parameter(torch.randn(5, 5)) 560 model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm") 561 ref_model = torch.ao.quantization.QuantWrapper(model) 562 ref_model = torch.ao.quantization.prepare_qat(ref_model) 563 self._test_obs( 564 ref_model, input_size=[5, 5], generate=False, check_numerics=False 565 ) 566 567 @skipIfNoFBGEMM 568 def test_linear_relu_package_quantization_transforms(self): 569 m = LinearReluFunctional(4).eval() 570 self._test_package(m, input_size=(1, 1, 4, 4), generate=False) 571