1# Owner(s): ["oncall: quantization"] 2 3# torch 4import io 5import itertools 6import unittest 7 8# Standard library 9from typing import List, Tuple 10 11import torch 12import torch.jit 13import torch.jit.quantized 14import torch.nn as nn 15import torch.nn.functional as F 16 17# torch.ao.quantization 18from torch.ao.quantization import ( 19 default_dynamic_qconfig, 20 default_histogram_observer, 21 default_observer, 22 default_per_channel_weight_observer, 23 default_qconfig, 24 default_weight_observer, 25 float16_dynamic_qconfig, 26 fuse_modules, 27 get_default_qconfig, 28 per_channel_dynamic_qconfig, 29 PlaceholderObserver, 30 QConfig, 31 quantize, 32 quantize_dynamic, 33 quantize_dynamic_jit, 34 quantize_jit, 35) 36 37# torch.ao.quantization.quantize_jit 38from torch.ao.quantization.quantize_jit import ( 39 convert_dynamic_jit, 40 convert_jit, 41 fuse_conv_bn_jit, 42 prepare_dynamic_jit, 43 prepare_jit, 44 script_qconfig, 45) 46from torch.jit._recursive import wrap_cpp_module 47from torch.testing import FileCheck 48 49# Annotated models 50from torch.testing._internal.common_quantization import ( 51 AnnotatedConvBnModel, 52 AnnotatedConvModel, 53 AnnotatedConvTransposeModel, 54 AnnotatedNestedModel, 55 AnnotatedSingleLayerLinearModel, 56 AnnotatedSkipQuantModel, 57 ConvBnModel, 58 ConvModel, 59 ConvTransposeModel, 60 default_per_channel_qconfig, 61 get_script_module, 62 NestedModel, 63 QuantizationTestCase, 64 SingleLayerLinearModel, 65 skipIfNoFBGEMM, 66 SkipQuantModel, 67 test_only_eval_fn, 68) 69 70# Testing utils 71from torch.testing._internal.common_quantized import ( 72 override_qengines, 73 qengine_is_fbgemm, 74 qengine_is_qnnpack, 75) 76from torch.testing._internal.common_utils import set_default_dtype 77from torch.testing._internal.jit_utils import ( 78 attrs_with_prefix, 79 get_forward, 80 get_forward_graph, 81) 82 83 84class TestQuantizeJitPasses(QuantizationTestCase): 85 """Test graph mode quantization passes used by quantize_jit""" 86 87 def test_skip_dequant_constant_prop(self): 88 class M(torch.nn.Module): 89 def __init__(self) -> None: 90 super().__init__() 91 self.conv = torch.nn.Conv2d(3, 5, 3).float() 92 93 def forward(self, x): 94 return self.conv(x) 95 96 m = torch.jit.script(M()) 97 observer = default_per_channel_weight_observer.with_args(ch_axis=1) 98 qconfig_dict = {"": QConfig(activation=default_observer, weight=observer)} 99 m = prepare_jit(m, qconfig_dict) 100 data = torch.randn(1, 3, 10, 10, dtype=torch.float) 101 102 m(data) 103 m = convert_jit(m, debug=True) 104 105 freezed = torch.jit.freeze(m) 106 freezed(data) 107 108 # After freezing, weight becomes Constant. 
109 # We have this pattern in the original graph: Constant f32_weight -> quant -> dequant 110 # After skipping dequant during Constant Propagation, the resulting graph will be: 111 # Constant int8_weight -> dequant 112 FileCheck().check_count("aten::quantize_per_tensor", 2, exactly=True).run( 113 freezed.graph 114 ) 115 FileCheck().check_count("aten::quantize_per_channel", 0, exactly=True).run( 116 freezed.graph 117 ) 118 FileCheck().check_count("aten::dequantize", 3, exactly=True).run(freezed.graph) 119 FileCheck().check("aten::quantize_per_tensor").check_next( 120 "aten::dequantize" 121 ).check_not("aten::quantize_per_channel").check("aten::dequantize").check_next( 122 "aten::conv2d" 123 ).check_next( 124 "aten::quantize_per_tensor" 125 ).check_next( 126 "aten::dequantize" 127 ).run( 128 freezed.graph 129 ) 130 131 def test_foldbn_trivial(self): 132 bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} 133 conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 134 135 # Test trivial case 136 class TestModule(torch.nn.Module): 137 def __init__(self, dim): 138 super().__init__() 139 self.conv = conv_module[dim](1, 20, 5, 1) 140 self.bn = bn_module[dim](num_features=20) 141 self.bn.eps = 0.0023 142 143 def forward(self, x): 144 x = self.conv(x) 145 x = self.bn(x) 146 return x 147 148 options = itertools.product([True, False], [2, 3]) 149 data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)} 150 # Check that the transformation doesn't change numerics 151 for tracing, dim in options: 152 eager = TestModule(dim).eval() 153 x = data[dim] 154 scripted_or_traced = get_script_module(eager, tracing, x).eval() 155 # Check that in the original script module's forward we have two 156 # CallMethod nodes. One of them should be for conv.forward and the other 157 # for bn.forward. 158 FileCheck().check_count( 159 'prim::CallMethod[name="forward"]', 2, exactly=True 160 ).run(str(get_forward(scripted_or_traced._c).graph)) 161 162 # Run FoldConvBatchnorm pass. 163 scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) 164 165 # Check that after the pass one of the CallMethods is gone (supposedly, 166 # the bn.forward). 167 FileCheck().check_count( 168 'prim::CallMethod[name="forward"]', 1, exactly=True 169 ).run(str(get_forward_graph(scripted_or_traced._c))) 170 171 # Check that the transformation doesn't change numerics 172 self.assertEqual(eager(x), scripted_or_traced(x)) 173 174 def test_foldbn_trivial_nobias(self): 175 bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} 176 conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 177 178 # Test trivial case 179 class TestModule(torch.nn.Module): 180 def __init__(self, dim): 181 super().__init__() 182 self.conv = conv_module[dim](1, 20, 5, 1, bias=False) 183 self.bn = bn_module[dim](num_features=20) 184 # to make sure new bias is not zero 185 self.bn.eps = 0.0027 186 self.bn.bias = torch.nn.Parameter(torch.rand([20])) 187 188 def forward(self, x): 189 x = self.conv(x) 190 x = self.bn(x) 191 return x 192 193 options = itertools.product([True, False], [2, 3]) 194 data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)} 195 for tracing, dim in options: 196 eager = TestModule(dim).eval() 197 x = data[dim] 198 scripted_or_traced = get_script_module(eager, tracing, x).eval() 199 # Check that in the original script module's forward we have two 200 # CallMethod nodes. One of them should be for conv.forward and the other 201 # for bn.forward. 
202 FileCheck().check_count( 203 'prim::CallMethod[name="forward"]', 2, exactly=True 204 ).run(str(get_forward_graph(scripted_or_traced._c))) 205 206 # Run FoldConvBatchnorm pass. 207 scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) 208 209 # Check that after the pass one of the CallMethods is gone (supposedly, 210 # the bn.forward). 211 FileCheck().check_count( 212 'prim::CallMethod[name="forward"]', 1, exactly=True 213 ).run(str(get_forward_graph(scripted_or_traced._c))) 214 215 # Check that the transformation doesn't change numerics 216 self.assertEqual(eager(x), scripted_or_traced(x)) 217 218 def test_foldbn_in_submodule(self): 219 bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} 220 conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 221 222 # Test that we find Conv-BN patterns in submodules 223 class SubModule(torch.nn.Module): 224 def __init__(self, dim): 225 super().__init__() 226 self.conv = conv_module[dim](1, 20, 5, 1) 227 self.bn = bn_module[dim](num_features=20) 228 229 def forward(self, x): 230 x = self.conv(x) 231 x = self.bn(x) 232 return x 233 234 class TestModule(torch.nn.Module): 235 def __init__(self, dim): 236 super().__init__() 237 self.sub = SubModule(dim) 238 239 def forward(self, x): 240 x = self.sub(x) 241 return x 242 243 options = itertools.product([True, False], [2, 3]) 244 data = {2: torch.rand(1, 1, 10, 10), 3: torch.rand(1, 1, 10, 10, 10)} 245 for tracing, dim in options: 246 eager = TestModule(dim).eval() 247 x = data[dim] 248 scripted_or_traced = get_script_module(eager, tracing, x).eval() 249 FileCheck().check_count( 250 'prim::CallMethod[name="forward"]', 2, exactly=True 251 ).run(str(get_forward_graph(scripted_or_traced.sub._c))) 252 253 scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) 254 255 FileCheck().check_count( 256 'prim::CallMethod[name="forward"]', 1, exactly=True 257 ).run(str(get_forward_graph(scripted_or_traced.sub._c))) 258 259 self.assertEqual(eager(x), scripted_or_traced(x)) 260 261 def test_foldbn_shared_classtype(self): 262 bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} 263 conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 264 265 class TestModule(torch.nn.Module): 266 def __init__(self, dim, bias=False): 267 super().__init__() 268 self.conv1 = conv_module[dim](5, 5, 3, bias=bias) 269 self.bn1 = bn_module[dim](num_features=5) 270 self.bn1.running_mean.fill_(-0.2) 271 self.bn1.bias = torch.nn.Parameter(torch.rand([5])) 272 # to make sure new bias is not zero 273 self.bn1.eps = 0.0023 274 self.conv2 = conv_module[dim](5, 5, 3, bias=bias) 275 self.bn2 = bn_module[dim](num_features=5) 276 self.bn2.eps = 0.0029 277 self.relu = torch.nn.ReLU() 278 279 def forward(self, x): 280 x = self.conv1(x) 281 x = self.bn1(x) 282 x = self.relu(x) 283 x = self.conv2(x) 284 x = self.bn2(x) 285 x = self.relu(x) 286 return x 287 288 options = itertools.product([True, False], [2, 2], [True, False]) 289 data = {2: torch.rand(1, 5, 6, 6), 3: torch.rand(1, 5, 6, 6, 6)} 290 for tracing, dim, bias in options: 291 eager = TestModule(dim, bias).eval() 292 x = data[dim] 293 scripted_or_traced = get_script_module(eager, tracing, x) 294 folded = fuse_conv_bn_jit(scripted_or_traced) 295 self.assertEqual(eager(x), scripted_or_traced(x)) 296 297 def test_foldbn_no_fusion(self): 298 """Test that we don't fuse the cases when module type does not match""" 299 300 class CustomConv(torch.nn.Module): 301 def forward(self, x): 302 return x 303 304 class CustomBn(torch.nn.Module): 305 def forward(self, x): 306 return x 307 
308 class M(torch.nn.Module): 309 def __init__(self) -> None: 310 super().__init__() 311 self.conv = CustomConv() 312 self.bn = CustomBn() 313 314 def forward(self, x): 315 return self.bn(self.conv(x)) 316 317 m = torch.jit.script(M()) 318 m = fuse_conv_bn_jit(m) 319 FileCheck().check_count("prim::CallMethod", 2, exactly=True).run(m.graph) 320 321 @set_default_dtype(torch.double) 322 def test_foldbn_complex_cases(self): 323 # This test case attempt to try combinations of conv2d/conv3d with bias/nobias 324 # as well as BatchNorm with affine/no-affine along with varying the 325 # number of layers. 326 # this only works when default dtype is double 327 bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} 328 conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 329 330 class SubModule(torch.nn.Module): 331 def __init__(self, dim, num_blocks, enable_bias, enable_affine): 332 super().__init__() 333 layers = [] 334 for i in range(num_blocks): 335 layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias)) 336 bn_obj = bn_module[dim](num_features=20, affine=enable_affine) 337 if enable_affine: 338 bn_obj.weight = torch.nn.Parameter( 339 torch.rand_like(bn_obj.weight) 340 ) 341 bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias)) 342 bn_obj.running_mean = torch.rand_like(bn_obj.running_mean) 343 bn_obj.running_var = torch.rand_like(bn_obj.running_var) 344 layers.append(bn_obj) 345 self.layers = nn.Sequential(*layers) 346 347 def forward(self, x): 348 return self.layers(x) 349 350 class TestModule(torch.nn.Module): 351 def __init__(self, dim, num_blocks, enable_bias, enable_affine): 352 super().__init__() 353 self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine) 354 355 def forward(self, x): 356 x = self.sub(x) 357 return x 358 359 options = itertools.product( 360 [True, False], [2, 3], [True, False], [True, False], [1, 2] 361 ) 362 data = {2: torch.rand(1, 20, 10, 10), 3: torch.rand(1, 20, 10, 10, 10)} 363 for tracing, dim, enable_bias, enable_bn_affine, num_layers in options: 364 eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval() 365 x = data[dim] 366 scripted_or_traced = get_script_module(eager, tracing, x).eval() 367 368 FileCheck().check_count( 369 'prim::CallMethod[name="forward"]', num_layers * 2, exactly=True 370 ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c))) 371 372 scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) 373 374 FileCheck().check_count( 375 'prim::CallMethod[name="forward"]', num_layers, exactly=True 376 ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c))) 377 378 self.assertEqual(eager(x), scripted_or_traced(x)) 379 380 def test_fuse_linear(self): 381 class FunctionalLinear(torch.nn.Module): 382 def __init__(self, weight, bias): 383 super().__init__() 384 self.weight = weight 385 self.bias = bias 386 387 def forward(self, x): 388 res = torch.matmul(x, self.weight.t()) 389 if self.bias is not None: 390 res.add_(self.bias) 391 return res 392 393 x1 = torch.rand(3) 394 w1 = torch.rand(5, 3) 395 b1 = torch.rand(5) 396 397 x2 = torch.rand(5, 5) 398 w2 = torch.rand(5, 5) 399 b2 = torch.rand(5) 400 401 x3 = torch.rand(5, 5, 5) 402 w3 = torch.rand(5, 5) 403 b3 = torch.rand(5) 404 for has_bias, (x, weight, b) in itertools.product( 405 [True, False], [(x1, w1, b1), (x2, w2, b2), (x3, w3, b3)] 406 ): 407 bias = b if has_bias else None 408 model = torch.jit.trace(FunctionalLinear(weight, bias), [x]) 409 for node in model.graph.nodes(): 410 if node.kind() == "aten::matmul": 411 source_range_1 = 
node.sourceRange() 412 torch._C._jit_pass_fuse_linear(model.graph) 413 for node in model.graph.nodes(): 414 if node.kind() == "aten::linear": 415 source_range_2 = node.sourceRange() 416 FileCheck().check("aten::linear").run(model.graph) 417 check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("] 418 for cn in check_not: 419 FileCheck().check_not(cn).run(model.graph) 420 # make sure it runs 421 self.assertTrue(source_range_1 == source_range_2) 422 model(x) 423 424 # check matmuls are not fused 425 class Matmul(torch.nn.Module): 426 def __init__(self, weight): 427 super().__init__() 428 self.weight = weight 429 430 def forward(self, x): 431 return torch.matmul(x, self.weight) 432 433 x = torch.rand(5, 6, 5) 434 w = torch.rand(5, 5, 100) 435 model = torch.jit.trace(Matmul(w), [x]) 436 torch._C._jit_pass_fuse_linear(model.graph) 437 # check 3d matmul is not fused 438 FileCheck().check("aten::matmul").run(model.graph) 439 FileCheck().check_not("aten::linear").run(model.graph) 440 # make sure it runs 441 model(x) 442 443 def test_insert_observers(self): 444 class M(torch.nn.Module): 445 def __init__(self) -> None: 446 super().__init__() 447 self.conv = torch.nn.Conv2d(3, 5, 3) 448 449 def forward(self, x): 450 return self.conv(x) 451 452 m = torch.jit.script(M()) 453 qconfig_dict = {"": default_qconfig} 454 m = prepare_jit(m, qconfig_dict) 455 # for input and output of conv 456 assert len(attrs_with_prefix(m, "_observer_")) == 2 457 # for weight 458 assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 459 460 def test_insert_observers_interface(self): 461 @torch.jit.interface 462 class SubInterface(torch.nn.Module): 463 def addOne(self, inp) -> torch.Tensor: 464 pass 465 466 class Sub(torch.nn.Module): 467 def __init__(self) -> None: 468 super().__init__() 469 self.fc = torch.nn.Linear(5, 5) 470 471 def addOne(self, inp): 472 return self.fc(inp) + 1 473 474 def forward(self, x): 475 return self.addOne(x) 476 477 class M(torch.nn.Module): 478 def __init__(self) -> None: 479 super().__init__() 480 self.conv = torch.nn.Conv2d(3, 5, 3) 481 self.sub = Sub() 482 483 def forward(self, x): 484 return self.sub(self.conv(x)) 485 486 m = torch.jit.script(M()) 487 qconfig_dict = {"sub.conv": default_qconfig} 488 m = prepare_jit(m, qconfig_dict) 489 490 def test_insert_observers_interface_unshare_type(self): 491 @torch.jit.interface 492 class OperatorIf(nn.Module): 493 def forward(self, inp: torch.Tensor) -> torch.Tensor: 494 pass 495 496 class Operator(nn.Module): 497 def __init__(self, a): 498 super().__init__() 499 self.a = a 500 501 def forward(self, inp: torch.Tensor) -> torch.Tensor: 502 return self.a * (inp + self.a) 503 504 class Inner(nn.Module): 505 op: OperatorIf 506 507 def __init__(self, op): 508 super().__init__() 509 self.op = op 510 511 def forward(self, inp): 512 return self.op(inp) 513 514 class Outer(nn.Module): 515 def __init__(self) -> None: 516 super().__init__() 517 self.inner_a = Inner(Operator(1)) 518 self.inner_b = Inner(Operator(3.0)) 519 520 def forward(self, inp): 521 return self.inner_a(inp) + self.inner_b(inp) 522 523 qconfig_dict = {"inner_a": default_qconfig, "inner_b": default_qconfig} 524 525 eager_model = Outer() 526 for tracing in [True, False]: 527 x = torch.rand(3) 528 script_model = get_script_module(eager_model, tracing, x) 529 # make sure it runs 530 prepare_jit(script_model, qconfig_dict) 531 532 def test_insert_observers_child_qconfig(self): 533 class Sub(torch.nn.Module): 534 def __init__(self) -> None: 535 super().__init__() 536 self.fc = 
torch.nn.Linear(5, 5) 537 538 def forward(self, x): 539 return self.fc(x) 540 541 class M(torch.nn.Module): 542 def __init__(self) -> None: 543 super().__init__() 544 self.conv = torch.nn.Conv2d(3, 5, 3) 545 self.sub = Sub() 546 547 def forward(self, x): 548 return self.sub(self.conv(x)) 549 550 m = torch.jit.script(M()) 551 qconfig_dict = {"sub.fc": default_qconfig} 552 m = prepare_jit(m, qconfig_dict) 553 # input and output of sub 554 assert len(attrs_with_prefix(m, "_observer_")) == 2 555 # not quantized 556 assert len(attrs_with_prefix(m.conv, "_observer_")) == 0 557 # no observers since we observe in the outer most call site 558 assert len(attrs_with_prefix(m.sub, "_observer_")) == 0 559 # weight of linear 560 assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1 561 562 @unittest.skipUnless( 563 "fbgemm" in torch.backends.quantized.supported_engines, 564 " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" 565 " with instruction set support avx2 or newer.", 566 ) 567 def test_insert_observers_skip_values(self): 568 class ConvFunctionalReLU(torch.nn.Module): 569 def __init__(self) -> None: 570 super().__init__() 571 self.conv = torch.nn.Conv2d(3, 5, 3) 572 573 def forward(self, x): 574 return F.relu(self.conv(x)) 575 576 class ConvReLUModule(torch.nn.Module): 577 def __init__(self) -> None: 578 super().__init__() 579 self.conv = torch.nn.Conv2d(3, 5, 3) 580 self.relu = torch.nn.ReLU() 581 582 def forward(self, x): 583 return self.relu(self.conv(x)) 584 585 class AddReLUModule(torch.nn.Module): 586 def __init__(self) -> None: 587 super().__init__() 588 self.relu = torch.nn.ReLU() 589 self.conv = torch.nn.Conv2d(3, 3, 3).float() 590 591 def forward(self, x): 592 out = self.conv(x) 593 out += x 594 return self.relu(out) 595 596 class AddFunctionalReLU(torch.nn.Module): 597 def __init__(self) -> None: 598 super().__init__() 599 self.conv = torch.nn.Conv2d(3, 3, 3).float() 600 601 def forward(self, x): 602 out = self.conv(x) 603 out += x 604 return F.relu(out) 605 606 def attrs_with_prefix(module, prefix): 607 return [x for x, _ in module._modules._c.items() if x.startswith(prefix)] 608 609 qconfig_dict = {"": default_qconfig} 610 m = torch.jit.script(ConvFunctionalReLU()) 611 m = prepare_jit(m, qconfig_dict) 612 # observer for weight of conv 613 assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 614 # observer for input of conv and output of relu 615 assert len(attrs_with_prefix(m, "_observer_")) == 2 616 617 m = torch.jit.script(ConvReLUModule()) 618 m = prepare_jit(m, qconfig_dict) 619 # observer for input of conv and output of relu 620 assert len(attrs_with_prefix(m, "_observer_")) == 2 621 # observer for weight of conv 622 assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 623 # observer for output of relu 624 assert len(attrs_with_prefix(m.relu, "_observer_")) == 0 625 626 m = torch.jit.script(AddReLUModule()) 627 qconfig_dict = {"": default_qconfig} 628 m = prepare_jit(m, qconfig_dict) 629 assert len(attrs_with_prefix(m, "_observer")) == 3 630 assert len(attrs_with_prefix(m.relu, "_observer")) == 0 631 FileCheck().check("aten::add_").check_not( 632 'Observer = prim::GetAttr[name="_observer_' 633 ).check("ReLU = prim::GetAttr").run(str(get_forward_graph(m._c))) 634 635 m = torch.jit.script(AddFunctionalReLU()) 636 qconfig_dict = {"": default_qconfig} 637 m = prepare_jit(m, qconfig_dict) 638 assert len(attrs_with_prefix(m, "_observer")) == 3 639 FileCheck().check("aten::add_").check_not( 640 'Observer = prim::GetAttr[name="_observer_' 641 
).check("CallFunction").check('Observer = prim::GetAttr[name="_observer_').run( 642 str(get_forward_graph(m._c)) 643 ) 644 645 def test_insert_observers_weight_dtype(self): 646 class M(torch.nn.Module): 647 def __init__(self) -> None: 648 super().__init__() 649 self.conv = torch.nn.Conv2d(3, 5, 3) 650 651 def forward(self, x): 652 return F.relu(self.conv(x)) 653 654 m = torch.jit.script(M()) 655 qconfig_dict = {"": default_qconfig} 656 m = prepare_jit(m, qconfig_dict) 657 activation_dtypes = { 658 obs.getattr("dtype") 659 for x, obs in m._modules._c.items() 660 if x.startswith("_observer_") 661 } 662 weight_dtypes = { 663 obs.getattr("dtype") 664 for x, obs in m.conv._modules._c.items() 665 if x.startswith("_observer_") 666 } 667 assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype" 668 assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype" 669 assert next(iter(activation_dtypes)) != next( 670 iter(weight_dtypes) 671 ), "Expected activation dtype to " 672 " be different from wegiht dtype" 673 674 def test_insert_observers_for_reused_weight(self): 675 class M(torch.nn.Module): 676 def forward(self, x, y, weight): 677 x = F.conv2d(x, weight) 678 y = F.conv2d(y, weight) 679 return x + y 680 681 m = torch.jit.script(M()).eval() 682 m = prepare_jit(m, {"": default_qconfig}) 683 # 3 for x, y, weight, one for output of each F.conv2d and one for output of add 684 assert len(attrs_with_prefix(m, "_observer")) == 6 685 686 def test_insert_observers_shared_class_type(self): 687 class M(torch.nn.Module): 688 def __init__(self) -> None: 689 super().__init__() 690 self.conv1 = torch.nn.Conv2d(3, 5, 3).float() 691 self.conv2 = torch.nn.Conv2d(3, 5, 3).float() 692 693 def forward(self, x): 694 return self.conv2(self.conv1(x)) 695 696 m = torch.jit.script(M()) 697 qconfig_dict = {"": default_qconfig} 698 m = prepare_jit(m, qconfig_dict) 699 # conv1 and conv2 shares the same type, we need to 700 # make sure we didn't quantize the type twice 701 conv1_observers = attrs_with_prefix(m.conv1, "_observer_") 702 conv2_observers = attrs_with_prefix(m.conv2, "_observer_") 703 assert len(conv1_observers) == 1, "Expected to have 1 observer submodules" 704 assert len(conv2_observers) == 1, "Expected to have 1 observer submodules" 705 assert ( 706 conv1_observers == conv2_observers 707 ), "Expect conv1 and conv2 to have same observers since the class type is shared" 708 709 def test_insert_observers_for_general_ops(self): 710 """Make sure we skip observers for ops that doesn't require 711 observation, e.g. 
flatten 712 """ 713 714 class M(torch.nn.Module): 715 def __init__(self) -> None: 716 super().__init__() 717 self.conv = torch.nn.Conv2d(3, 3, 3).float() 718 719 def forward(self, x): 720 x = self.conv(x) 721 x = torch.flatten(x) 722 return x 723 724 m = torch.jit.script(M()) 725 qconfig_dict = {"": default_qconfig} 726 m = prepare_jit(m, qconfig_dict) 727 # input and output of conv 728 assert len(attrs_with_prefix(m, "_observer_")) == 2 729 FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( 730 'prim::GetAttr[name="conv"]' 731 ).check("prim::CallMethod").check( 732 'Observer = prim::GetAttr[name="_observer_' 733 ).check( 734 "aten::flatten" 735 ).check_not( 736 'Observer = prim::GetAttr[name="_observer_' 737 ).run( 738 m.graph 739 ) 740 741 # TODO: this is too long, split this to test_insert_observers.py and remove 742 # insrt_observers prefix 743 def test_insert_observers_propagate_observed(self): 744 """Make sure we propagate observed property through general ops""" 745 746 class M(torch.nn.Module): 747 def __init__(self) -> None: 748 super().__init__() 749 self.conv1 = torch.nn.Conv2d(3, 3, 3).float() 750 self.conv2 = torch.nn.Conv2d(3, 3, 3).float() 751 752 def forward(self, x): 753 x = self.conv1(x) 754 x = torch.flatten(x) 755 # we don't want to insert observer for input of self.conv2 756 # because output of self.conv1 is already observed 757 x = self.conv2(x) 758 return x 759 760 m = torch.jit.script(M()) 761 qconfig_dict = {"": default_qconfig} 762 m = prepare_jit(m, qconfig_dict) 763 # input and output of conv 764 assert len(attrs_with_prefix(m, "_observer_")) == 3 765 FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( 766 'prim::GetAttr[name="conv1"]' 767 ).check("prim::CallMethod").check( 768 'Observer = prim::GetAttr[name="_observer_' 769 ).check( 770 "aten::flatten" 771 ).check_not( 772 'Observer = prim::GetAttr[name="_observer_' 773 ).check( 774 'prim::GetAttr[name="conv2"]' 775 ).check( 776 'Observer = prim::GetAttr[name="_observer_' 777 ).run( 778 m.graph 779 ) 780 781 def test_insert_observers_propagate_observed_in_submodule(self): 782 """Make sure we propagate observed property through general ops""" 783 784 class M(torch.nn.Module): 785 def __init__(self) -> None: 786 super().__init__() 787 self.conv1 = torch.nn.Conv2d(3, 3, 3).float() 788 self.conv2 = torch.nn.Conv2d(3, 3, 3).float() 789 self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 790 791 def forward(self, x): 792 x = self.conv1(x) 793 x = self.avgpool(x) 794 # we don't want to insert observer for input of self.conv2 795 # because output of self.conv1 is already observed 796 x = self.conv2(x) 797 return x 798 799 m = torch.jit.script(M()) 800 qconfig_dict = {"": default_qconfig} 801 m = prepare_jit(m, qconfig_dict) 802 # input and output of conv 803 assert len(attrs_with_prefix(m, "_observer_")) == 3 804 FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( 805 'prim::GetAttr[name="conv1"]' 806 ).check("prim::CallMethod").check( 807 'Observer = prim::GetAttr[name="_observer_' 808 ).check( 809 "prim::CallMethod" 810 ).check_not( 811 'Observer = prim::GetAttr[name="_observer_' 812 ).check( 813 'prim::GetAttr[name="conv2"]' 814 ).check( 815 'Observer = prim::GetAttr[name="_observer_' 816 ).run( 817 m.graph 818 ) 819 820 def test_insert_observers_propagate_observed_for_function(self): 821 def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor: 822 batchsize, num_channels, height, width = x.data.size() 823 channels_per_group = num_channels // groups 824 
# reshape 825 x = x.view(batchsize, groups, channels_per_group, height, width) 826 x = torch.transpose(x, 1, 2).contiguous() 827 # flatten 828 x = x.view(batchsize, -1, height, width) 829 return x 830 831 class M(torch.nn.Module): 832 def __init__(self) -> None: 833 super().__init__() 834 self.conv1 = torch.nn.Conv2d(3, 3, 1).float() 835 self.conv2 = torch.nn.Conv2d(3, 3, 1).float() 836 837 def forward(self, x): 838 x = self.conv1(x) 839 x = channel_shuffle(x, 1) 840 x = self.conv2(x) 841 return x 842 843 data = [ 844 ( 845 torch.rand((1, 3, 10, 10), dtype=torch.float), 846 torch.randint(0, 1, (1,), dtype=torch.long), 847 ) 848 for _ in range(2) 849 ] 850 m = torch.jit.script(M()).eval() 851 m = prepare_jit(m, {"": default_qconfig}) 852 # we want to test that channel_shuffle is going to pass 853 # the observed property from the output of conv1 to input of conv2 854 # so that we don't insert observers for input of conv2 855 assert ( 856 len( 857 attrs_with_prefix( 858 m, 859 "_observer_", 860 ) 861 ) 862 == 3 863 ) 864 865 def test_insert_observers_for_if(self): 866 class QuantProp(torch.nn.Module): 867 def __init__(self, use_skip): 868 super().__init__() 869 self.conv = torch.nn.Conv2d(3, 3, 1).float() 870 self.use_skip = use_skip 871 872 def forward(self, x): 873 if self.use_skip: 874 x = self.conv(x) 875 return torch.reshape(x, x.shape) 876 else: 877 x = self.conv(x) 878 return torch.reshape(x, x.shape) 879 880 class Res(torch.nn.Module): 881 def __init__(self, use_skip): 882 super().__init__() 883 self.conv = torch.nn.Conv2d(3, 3, 1).float() 884 self.use_skip = use_skip 885 886 def forward(self, x): 887 if self.use_skip: 888 return self.conv(x) 889 else: 890 return self.conv(x) 891 892 class M(torch.nn.Module): 893 def __init__(self) -> None: 894 super().__init__() 895 self.quant_prop = QuantProp(True) 896 self.res = Res(False) 897 898 def forward(self, x): 899 x = self.quant_prop(x) 900 x = self.res(x) 901 return x 902 903 data = [torch.rand(1, 3, 10, 10, dtype=torch.float)] 904 result = {False: [1, 2, 2], True: [2, 1, 0]} 905 for tracing in [True, False]: 906 if tracing: 907 m = torch.jit.trace(M(), data).eval() 908 else: 909 m = torch.jit.script(M()).eval() 910 m = prepare_jit(m, {"": default_qconfig}) 911 assert ( 912 len( 913 attrs_with_prefix( 914 m, 915 "_observer_", 916 ) 917 ) 918 == result[tracing][0] 919 ) 920 assert ( 921 len( 922 attrs_with_prefix( 923 m.quant_prop, 924 "_observer_", 925 ) 926 ) 927 == result[tracing][1] 928 ) 929 assert ( 930 len( 931 attrs_with_prefix( 932 m.res, 933 "_observer_", 934 ) 935 ) 936 == result[tracing][2] 937 ) 938 939 def test_insert_observers_for_nested_if(self): 940 class Res(torch.nn.Module): 941 def __init__(self, use_skip): 942 super().__init__() 943 self.conv = torch.nn.Conv2d(3, 3, 1).float() 944 self.cond = use_skip 945 self.use_skip = use_skip 946 947 def forward(self, x): 948 if self.use_skip: 949 if self.cond: 950 return self.conv(x) 951 else: 952 return self.conv(x) 953 else: 954 return self.conv(x) 955 956 class M(torch.nn.Module): 957 def __init__(self) -> None: 958 super().__init__() 959 self.res1 = Res(True) 960 self.res2 = Res(False) 961 962 def forward(self, x): 963 x = self.res1(x) 964 x = self.res2(x) 965 return x 966 967 data = torch.rand((1, 3, 10, 10), dtype=torch.float) 968 result = {True: 3, False: 1} 969 for tracing in [True, False]: 970 if tracing: 971 m = torch.jit.trace(M(), data).eval() 972 else: 973 m = torch.jit.script(M()).eval() 974 m = prepare_jit(m, {"": default_qconfig}) 975 assert 
len(attrs_with_prefix(m, "_observer_")) == result[tracing] 976 977 def test_insert_observers_for_if_consistent_observation(self): 978 """check quantization for if works as long as 979 output of all branches are quantized/observed consistently 980 """ 981 982 class M(torch.nn.Module): 983 def __init__(self, cond): 984 super().__init__() 985 self.conv = torch.nn.Conv2d(3, 3, 3).float() 986 self.cond = cond 987 988 def forward(self, x): 989 x = self.conv(x) 990 # x is already observed 991 if self.cond: 992 x = torch.flatten(x) 993 return x 994 995 class M2(torch.nn.Module): 996 def __init__(self, cond): 997 super().__init__() 998 self.conv1 = torch.nn.Conv2d(3, 3, 3).float() 999 self.conv2 = torch.nn.Conv2d(3, 3, 3).float() 1000 self.cond = cond 1001 1002 def forward(self, x): 1003 x = self.conv1(x) 1004 if self.cond: 1005 x = self.conv2(x) 1006 # x will be observed in the branch 1007 else: 1008 x = torch.flatten(x) 1009 # since output for both branch are quantized 1010 # the if node is quantized consistently 1011 return x 1012 1013 data = torch.rand((1, 3, 5, 5), dtype=torch.float) 1014 options = list(itertools.product([True, False], [True, False])) 1015 for cond, tracing in options: 1016 if tracing: 1017 m = torch.jit.trace(M(cond), data) 1018 else: 1019 m = torch.jit.script(M(cond)) 1020 m = prepare_jit(m, {"": default_qconfig}) 1021 assert len(attrs_with_prefix(m, "_observer_")) == 2 1022 1023 for cond, tracing in options: 1024 if tracing: 1025 m = torch.jit.trace(M2(cond), data) 1026 else: 1027 m = torch.jit.script(M2(cond)) 1028 m = prepare_jit(m, {"": default_qconfig}) 1029 num_observers = 2 if tracing and not cond else 3 1030 assert len(attrs_with_prefix(m, "_observer_")) == num_observers 1031 1032 def test_insert_quant_dequant(self): 1033 class M(torch.nn.Module): 1034 def __init__(self) -> None: 1035 super().__init__() 1036 self.conv = torch.nn.Conv2d(3, 5, 3).float() 1037 1038 def forward(self, x): 1039 return self.conv(x) 1040 1041 for is_per_channel in [True, False]: 1042 m = torch.jit.script(M()) 1043 observer = ( 1044 default_per_channel_weight_observer.with_args(ch_axis=1) 1045 if is_per_channel 1046 else default_observer 1047 ) 1048 qconfig_dict = {"": QConfig(activation=observer, weight=observer)} 1049 m = prepare_jit(m, qconfig_dict) 1050 data = torch.randn(1, 3, 10, 10, dtype=torch.float) 1051 1052 m(data) 1053 m = convert_jit(m, debug=True) 1054 assert ( 1055 len(m._modules._c.items()) == 1 1056 ), "Expected to have single submodule of conv" 1057 # make sure the quantized model is executable 1058 m(data) 1059 quant_func = ( 1060 "aten::quantize_per_channel" 1061 if is_per_channel 1062 else "aten::quantize_per_tensor" 1063 ) 1064 FileCheck().check_count(quant_func, 3, exactly=True).run(m.graph) 1065 1066 def test_insert_quant_dequant_shared_class_type(self): 1067 class M(torch.nn.Module): 1068 def __init__(self) -> None: 1069 super().__init__() 1070 self.conv1 = torch.nn.Conv2d(3, 3, 3).float() 1071 self.conv2 = torch.nn.Conv2d(3, 3, 3).float() 1072 1073 def forward(self, x): 1074 return self.conv2(self.conv1(x)) 1075 1076 for is_per_channel in [True, False]: 1077 m = torch.jit.script(M()) 1078 observer = ( 1079 default_per_channel_weight_observer.with_args(ch_axis=1) 1080 if is_per_channel 1081 else default_observer 1082 ) 1083 qconfig = QConfig(activation=observer, weight=observer) 1084 qconfig_dict = {"": qconfig} 1085 m = prepare_jit(m, qconfig_dict) 1086 # observers for input, output and value between conv1/conv2 1087 assert ( 1088 len(attrs_with_prefix(m, 
"_observer_")) == 3 1089 ), "Expected to have 3 obervers" 1090 # observer for weight 1091 assert ( 1092 len(attrs_with_prefix(m.conv1, "_observer_")) == 1 1093 ), "Expected to have 1 obervers" 1094 # observer for weight 1095 assert ( 1096 len(attrs_with_prefix(m.conv2, "_observer_")) == 1 1097 ), "Expected to have 1 obervers" 1098 1099 data = torch.randn(1, 3, 10, 10, dtype=torch.float) 1100 m(data) 1101 m = convert_jit(m, debug=True) 1102 m(data) 1103 assert m.conv1._c._type() == m.conv2._c._type() 1104 1105 # check all observers have been removed 1106 assert ( 1107 len(attrs_with_prefix(m, "_observer_")) == 0 1108 ), "Expected to have 0 obervers" 1109 assert ( 1110 len(attrs_with_prefix(m.conv1, "_observer_")) == 0 1111 ), "Expected to have 0 obervers" 1112 assert ( 1113 len(attrs_with_prefix(m.conv2, "_observer_")) == 0 1114 ), "Expected to have 0 obervers" 1115 1116 quant_func = ( 1117 "aten::quantize_per_channel" 1118 if is_per_channel 1119 else "aten::quantize_per_tensor" 1120 ) 1121 for module in ["conv1", "conv2"]: 1122 conv = m._c.getattr(module) 1123 # quantize weight 1124 FileCheck().check(quant_func).check_next("aten::dequantize").check( 1125 'prim::CallMethod[name="_conv_forward"]' 1126 ).check("return").run(get_forward_graph(conv)) 1127 # no quantize node in _conv_forward 1128 FileCheck().check_not(quant_func).check("aten::conv2d").check_not( 1129 quant_func 1130 ).check("return").run(conv._get_method("_conv_forward").graph) 1131 1132 def test_dedup_module_uses(self): 1133 class M(torch.nn.Module): 1134 def __init__(self) -> None: 1135 super().__init__() 1136 self.relu = torch.nn.ReLU() 1137 1138 def forward(self, x): 1139 x = self.relu(x) 1140 x -= 0.5 1141 return self.relu(x) 1142 1143 data = torch.randn((2, 2)) 1144 m = torch.jit.script(M()) 1145 ref_res = m(data) 1146 assert ( 1147 len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 1 1148 ), "Expected to have 1 relu modules after dedup module uses" 1149 torch._C._jit_pass_dedup_module_uses(m._c) 1150 m = torch.jit._recursive.wrap_cpp_module(m._c) 1151 res = m(data) 1152 assert ( 1153 len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 2 1154 ), "Expected to have 2 relu modules after dedup module uses" 1155 self.assertEqual(res, ref_res) 1156 1157 def test_replicate_dequantize(self): 1158 class M(torch.nn.Module): 1159 def __init__(self) -> None: 1160 super().__init__() 1161 self.conv = torch.nn.Conv2d(3, 3, 1).float() 1162 1163 def forward(self, x): 1164 x = torch.dequantize(x) 1165 r = self.conv(x) 1166 r += x 1167 return r 1168 1169 x = torch.randn([1, 3, 10, 10], dtype=torch.float) 1170 x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8) 1171 m = torch.jit.script(M()) 1172 ref_res = m(x) 1173 FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) 1174 torch._C._jit_pass_replicate_dequantize(m.graph) 1175 FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph) 1176 res = get_forward(m._c)(x) 1177 self.assertEqual(res, ref_res) 1178 1179 def test_replicate_dequantize_in_block(self): 1180 class M(torch.nn.Module): 1181 def __init__(self, cond): 1182 super().__init__() 1183 self.conv = torch.nn.Conv2d(3, 3, 1).float() 1184 1185 self.cond = cond 1186 1187 def forward(self, x): 1188 x = torch.dequantize(x) 1189 if self.cond: 1190 x = self.conv(x) 1191 else: 1192 x = x + 3 1193 return x 1194 1195 x = torch.randn([1, 3, 10, 10], dtype=torch.float) 1196 x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8) 1197 m = torch.jit.script(M(True)) 1198 
ref_res = m(x) 1199 FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) 1200 torch._C._jit_pass_replicate_dequantize(m.graph) 1201 FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph) 1202 # check dequantize is right before CallMethod of conv 1203 FileCheck().check("aten::dequantize").check_next("CallMethod").run(m.graph) 1204 # check dequantize is right before add 1205 FileCheck().check("aten::dequantize").check("aten::dequantize").check_next( 1206 "aten::add" 1207 ).run(m.graph) 1208 res = get_forward(m._c)(x) 1209 self.assertEqual(res, ref_res) 1210 1211 def test_swap_functional_linear(self): 1212 # TODO: This pass replaces any function called "linear" with "aten::linear" 1213 # No longer necessary, and also quite surprising 1214 def linear(input, weight, bias): 1215 return torch.nn.functional.linear(input, weight, bias) 1216 1217 class M(torch.nn.Module): 1218 def forward(self, x, weight, bias): 1219 x = torch.dequantize(x) 1220 weight = torch.dequantize(weight) 1221 x = linear(x, weight, bias) 1222 x = torch.quantize_per_tensor( 1223 x, scale=1.0, zero_point=0, dtype=torch.quint8 1224 ) 1225 return x 1226 1227 x = torch.rand((10, 5), dtype=torch.float) 1228 x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8) 1229 weight = torch.rand((5, 5), dtype=torch.float) 1230 weight = torch.quantize_per_tensor( 1231 weight, scale=0.5, zero_point=1, dtype=torch.qint8 1232 ) 1233 bias = torch.rand((5), dtype=torch.float) 1234 m = torch.jit.script(M()) 1235 ref_res = m(x, weight, bias) 1236 FileCheck().check("CallFunction").run(m.graph) 1237 torch._C._jit_pass_swap_functional_linear(m.graph) 1238 FileCheck().check("aten::linear").check_not("CallFunction").run(m.graph) 1239 res = m(x, weight, bias) 1240 self.assertEqual(res, ref_res) 1241 1242 def test_replicate_quantize_for_if(self): 1243 """We want to move quantize nodes for output of prim::If 1244 inside the prim::If blocks so that we can match quantization 1245 patterns. 
1246 """ 1247 1248 class Res(torch.nn.Module): 1249 def __init__(self) -> None: 1250 super().__init__() 1251 self.conv = torch.nn.Conv2d(3, 3, 1).float() 1252 self.conv2 = torch.nn.Conv2d(3, 3, 1).float() 1253 self.use_skip = True 1254 1255 def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor: 1256 # to avoid being frozen 1257 self.use_skip = cond 1258 if self.use_skip: 1259 return self.conv(x) 1260 else: 1261 return self.conv2(x) 1262 1263 class M(torch.nn.Module): 1264 def __init__(self) -> None: 1265 super().__init__() 1266 self.res1 = Res() 1267 self.res2 = Res() 1268 1269 def forward(self, x): 1270 x = self.res1(x, True) 1271 x = self.res2(x, False) 1272 return x 1273 1274 data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] 1275 qconfig_dict = {"": default_qconfig} 1276 m = torch.jit.script(M()).eval() 1277 m = quantize_jit(m, qconfig_dict, test_only_eval_fn, [data]) 1278 # make sure patterns in both branches are fused 1279 FileCheck().check_count("quantized::conv2d(", 4, exactly=True).run(m.graph) 1280 1281 def test_finalize_for_linear(self): 1282 class M(torch.nn.Module): 1283 def __init__(self) -> None: 1284 super().__init__() 1285 self.fc = torch.nn.Linear(5, 5).float() 1286 1287 def forward(self, x): 1288 return self.fc(x) 1289 1290 data = [[torch.rand((1, 5), dtype=torch.float)]] 1291 qconfig_dict = {"": default_qconfig} 1292 model = torch.jit.script(M()).eval() 1293 model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data]) 1294 # make sure there is only one quantize_per_tensor for input 1295 # and linear_prepack is folded 1296 FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).check_not( 1297 "quantized::linear_prepack" 1298 ).check("quantized::linear").run(model.graph) 1299 1300 def test_inplace_option(self): 1301 for tracing in [True, False]: 1302 model = get_script_module( 1303 torch.nn.Conv2d(3, 3, 3).float(), tracing, self.img_data_2d[0][0] 1304 ) 1305 qconfig_dict = {"": default_qconfig} 1306 quantize_jit( 1307 model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True 1308 ) 1309 FileCheck().check("quantized::conv2d").run(model.graph) 1310 1311 FileCheck().check_not("aten::conv2d").run(model.graph) 1312 1313 def test_finalize_debug(self): 1314 class M(torch.nn.Module): 1315 def __init__(self) -> None: 1316 super().__init__() 1317 self.conv = torch.nn.Conv2d(3, 3, 3).float() 1318 self.avgpool = torch.nn.AvgPool2d(3) 1319 1320 def forward(self, x): 1321 x = self.conv(x) 1322 x = self.avgpool(x) 1323 return x 1324 1325 data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] 1326 qconfig_dict = {"": default_qconfig} 1327 model = torch.jit.script(M()).eval() 1328 model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data], debug=True) 1329 FileCheck().check_not("quantized::conv2d").check("aten::conv2d").check( 1330 "aten::avg_pool2d" 1331 ).check("aten::q_scale").check_next("aten::q_zero_point").check_next( 1332 "prim::dtype" 1333 ).check_next( 1334 "aten::quantize_per_tensor" 1335 ).check( 1336 "aten::dequantize" 1337 ).run( 1338 model.graph 1339 ) 1340 1341 def test_module_list(self): 1342 class SimpleLinearLayer(torch.nn.Module): 1343 def __init__(self) -> None: 1344 super().__init__() 1345 self.fc = torch.nn.Linear(5, 5).float() 1346 1347 def forward(self, x): 1348 return self.fc(x) 1349 1350 class ComplexModel(torch.nn.Module): 1351 def __init__(self) -> None: 1352 super().__init__() 1353 self.layers = torch.nn.ModuleList( 1354 [SimpleLinearLayer() for i in range(2)] 1355 ) 1356 1357 def forward(self, x: 
torch.Tensor) -> List[torch.Tensor]: 1358 states = [] 1359 for layer in self.layers: 1360 val = layer(x) 1361 states.append(val) 1362 return states 1363 1364 data = torch.rand((1, 5), dtype=torch.float) 1365 qconfig_dict = {"": default_qconfig} 1366 model = torch.jit.script(ComplexModel()).eval() 1367 model = prepare_jit(model, qconfig_dict) 1368 assert len(attrs_with_prefix(model, "_observer")) == 3 1369 model(data) 1370 model = convert_jit(model, debug=False) 1371 FileCheck().check("quantized::linear").check("quantized::linear").run( 1372 model.graph 1373 ) 1374 1375 def test_conv_trace(self): 1376 class M(torch.nn.Module): 1377 def __init__(self) -> None: 1378 super().__init__() 1379 self.conv1d = torch.nn.Conv1d(3, 3, 3).float() 1380 self.conv2d = torch.nn.Conv2d(3, 3, 3).float() 1381 self.conv3d = torch.nn.Conv3d(3, 3, 3).float() 1382 1383 def forward(self, x, y, z): 1384 a = self.conv1d(x) 1385 b = self.conv2d(y) 1386 c = self.conv3d(z) 1387 return (a, b, c) 1388 1389 qconfig_dict = {"": default_qconfig} 1390 inputs = ( 1391 torch.rand((1, 3, 10), dtype=torch.float), 1392 torch.rand((1, 3, 10, 10), dtype=torch.float), 1393 torch.rand((1, 3, 10, 10, 10), dtype=torch.float), 1394 ) 1395 model = torch.jit.trace(M(), inputs).eval() 1396 m = prepare_jit(model, qconfig_dict) 1397 FileCheck().check("aten::conv1d").check_not("aten::_convolution").run( 1398 str(get_forward_graph(m.conv1d._c)) 1399 ) 1400 FileCheck().check("aten::conv2d").check_not("aten::_convolution").run( 1401 str(get_forward_graph(m.conv2d._c)) 1402 ) 1403 FileCheck().check("aten::conv3d").check_not("aten::_convolution").run( 1404 str(get_forward_graph(m.conv3d._c)) 1405 ) 1406 1407 def test_convtranspose_trace(self): 1408 class M(torch.nn.Module): 1409 def __init__(self) -> None: 1410 super().__init__() 1411 self.convtranspose1d = torch.nn.ConvTranspose1d(3, 3, 3).float() 1412 self.convtranspose2d = torch.nn.ConvTranspose2d(3, 3, 3).float() 1413 self.convtranspose3d = torch.nn.ConvTranspose3d(3, 3, 3).float() 1414 1415 def forward(self, x, y, z): 1416 a = self.convtranspose1d(x) 1417 b = self.convtranspose2d(y) 1418 c = self.convtranspose3d(z) 1419 return (a, b, c) 1420 1421 qconfig_dict = {"": default_qconfig} 1422 inputs = ( 1423 torch.rand((1, 3, 10), dtype=torch.float), 1424 torch.rand((1, 3, 10, 10), dtype=torch.float), 1425 torch.rand((1, 3, 10, 10, 10), dtype=torch.float), 1426 ) 1427 model = torch.jit.trace(M(), inputs).eval() 1428 m = prepare_jit(model, qconfig_dict) 1429 FileCheck().check("aten::conv_transpose1d").check_not("aten::_convolution").run( 1430 str(get_forward_graph(m.convtranspose1d._c)) 1431 ) 1432 FileCheck().check("aten::conv_transpose2d").check_not("aten::_convolution").run( 1433 str(get_forward_graph(m.convtranspose2d._c)) 1434 ) 1435 FileCheck().check("aten::conv_transpose3d").check_not("aten::_convolution").run( 1436 str(get_forward_graph(m.convtranspose3d._c)) 1437 ) 1438 1439 @unittest.skipUnless( 1440 "fbgemm" in torch.backends.quantized.supported_engines, 1441 " Quantized operations require FBGEMM. 
FBGEMM is only optimized for CPUs" 1442 " with instruction set support avx2 or newer.", 1443 ) 1444 def test_replicate_dequant_same_value(self): 1445 class Mul(torch.nn.Module): 1446 def __init__(self) -> None: 1447 super().__init__() 1448 self.conv = torch.nn.Conv2d(3, 3, 3).float() 1449 1450 def forward(self, x): 1451 x = self.conv(x) 1452 return x * x 1453 1454 data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] 1455 qconfig_dict = {"": default_qconfig} 1456 model = torch.jit.script(Mul()).eval() 1457 m = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data]) 1458 FileCheck().check("quantized::mul(").check_not("aten::mul").run(m.graph) 1459 1460 def test_interface_with_fork(self): 1461 class SubModule(torch.nn.Module): 1462 def __init__(self) -> None: 1463 super().__init__() 1464 self.embedding1 = torch.nn.EmbeddingBag( 1465 num_embeddings=10, 1466 embedding_dim=12, 1467 include_last_offset=True, 1468 sparse=False, 1469 mode="sum", 1470 ) 1471 1472 def forward(self, x, y): 1473 return self.embedding1(x, y) 1474 1475 class OrigMod(torch.nn.Module): 1476 def __init__(self) -> None: 1477 super().__init__() 1478 self.embedding1 = torch.nn.EmbeddingBag( 1479 num_embeddings=10, 1480 embedding_dim=12, 1481 include_last_offset=True, 1482 sparse=False, 1483 mode="sum", 1484 ) 1485 1486 def forward(self, x, y): 1487 return self.embedding1(x, y) 1488 1489 @torch.jit.interface 1490 class ModInterface(torch.nn.Module): 1491 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 1492 pass 1493 1494 class TestModule(torch.nn.Module): 1495 proxy_mod: ModInterface 1496 1497 def __init__(self) -> None: 1498 super().__init__() 1499 self.proxy_mod = OrigMod() 1500 self.sub = SubModule() 1501 1502 def forward(self, x, y): 1503 a = self.proxy_mod(x, y) 1504 b = self.sub(x, y) 1505 return b 1506 1507 class MainModule(torch.nn.Module): 1508 def __init__(self) -> None: 1509 super().__init__() 1510 self.test = TestModule() 1511 1512 def forward(self, x, y): 1513 fut = torch.jit._fork(self.test.forward, x, y) 1514 z = torch.jit._wait(fut) 1515 return z 1516 1517 indices = torch.tensor( 1518 [ 1519 9, 1520 6, 1521 5, 1522 7, 1523 8, 1524 8, 1525 9, 1526 2, 1527 8, 1528 6, 1529 6, 1530 9, 1531 1, 1532 6, 1533 8, 1534 8, 1535 3, 1536 2, 1537 3, 1538 6, 1539 3, 1540 6, 1541 5, 1542 7, 1543 0, 1544 8, 1545 4, 1546 6, 1547 5, 1548 8, 1549 2, 1550 3, 1551 ] 1552 ) 1553 offsets = torch.tensor([0, 19, 20, 28, 28, 32]) 1554 m = torch.jit.trace(MainModule(), (indices, offsets)) 1555 m.eval() 1556 1557 int8_qconfig = QConfig( 1558 activation=PlaceholderObserver.with_args( 1559 dtype=torch.float, custom_op_name="embedding_bag_byte" 1560 ), 1561 weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"), 1562 ) 1563 1564 m = prepare_jit(m, {"": int8_qconfig}) 1565 m = convert_jit(m) 1566 FileCheck().check("quantized::embedding_bag_byte_rowwise_offsets").run(m.graph) 1567 1568 @skipIfNoFBGEMM 1569 def test_quantize_fork_wait(self): 1570 """Tests the case where fork and wait calls are in different subgraphs 1571 Calling inline fork-wait only removes the fork call and leaves aten::wait 1572 calls in the graph, with Tensor as input (instead of Future[Tensor]) 1573 """ 1574 1575 class MainModule(nn.Module): 1576 def __init__(self) -> None: 1577 super().__init__() 1578 self.fork_ops = ForkModule() 1579 1580 def init_values(self, x): 1581 shared_module = self.fork_ops(x) 1582 self.fork_dict = shared_module 1583 1584 def forward(self, x): 1585 val = torch.jit._wait(self.fork_ops(x)) 1586 return 
val 1587 1588 class TestModule(torch.nn.Module): 1589 def forward(self, x): 1590 w = torch.ones(5, 5) 1591 b = torch.zeros(5) 1592 return torch.nn.functional.linear(x, w, b) 1593 1594 class ForkModule(nn.Module): 1595 def __init__(self) -> None: 1596 super().__init__() 1597 self.test = TestModule() 1598 1599 def forward(self, x): 1600 fut = torch.jit._fork(self.test.forward, x) 1601 return fut 1602 1603 model = MainModule().eval() 1604 traced = torch.jit.trace(model, (torch.randn(5, 5),)) 1605 model = prepare_dynamic_jit(traced, {"": default_qconfig}) 1606 model = convert_dynamic_jit(model) 1607 FileCheck().check("quantized::linear_dynamic").run(model.graph) 1608 # Make sure model save works 1609 b = io.BytesIO() 1610 torch.jit.save(model, b) 1611 1612 1613class TestQuantizeJitOps(QuantizationTestCase): 1614 """Test graph mode post training static quantization works 1615 for individual ops end to end. 1616 """ 1617 1618 @skipIfNoFBGEMM 1619 def test_linear(self): 1620 class ModuleLinear(torch.nn.Module): 1621 def __init__(self, has_relu=False, f_relu=False): 1622 super().__init__() 1623 self.linear = torch.nn.Linear(30, 4).float() 1624 if has_relu: 1625 if f_relu: 1626 self.relu = F.relu 1627 else: 1628 self.relu = torch.nn.ReLU() 1629 else: 1630 self.relu = torch.nn.Identity() 1631 1632 def forward(self, x): 1633 return self.relu(self.linear(x)) 1634 1635 class FuncLinear(torch.nn.Module): 1636 def __init__(self, has_relu=False, f_relu=False): 1637 super().__init__() 1638 self.w = torch.randn(4, 30) 1639 self.b = torch.randn(4) 1640 if has_relu: 1641 if f_relu: 1642 self.relu = F.relu 1643 else: 1644 self.relu = torch.nn.ReLU() 1645 else: 1646 self.relu = torch.nn.Identity() 1647 1648 def forward(self, x): 1649 return self.relu(F.linear(x, self.w, self.b)) 1650 1651 data = [[torch.rand((1, 30), dtype=torch.float)]] 1652 for model, tracing in itertools.product( 1653 [ModuleLinear(has_relu=False), FuncLinear(has_relu=False)], [True, False] 1654 ): 1655 model = self.checkGraphModeOp(model, data, "quantized::linear", tracing) 1656 FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( 1657 model.graph 1658 ) 1659 FileCheck().check_not("quantized::linear_prepack").run(model.graph) 1660 1661 for f_relu, tracing in itertools.product([True, False], [True, False]): 1662 for model in [ 1663 ModuleLinear(has_relu=True, f_relu=f_relu), 1664 FuncLinear(has_relu=True, f_relu=f_relu), 1665 ]: 1666 model = self.checkGraphModeOp( 1667 model, data, "quantized::linear_relu", tracing 1668 ) 1669 checker = ( 1670 FileCheck() 1671 .check_not("aten::linear") 1672 .check_not("aten::relu") 1673 .check_not("quantized::linear(") 1674 .check_not("quantized::relu(") 1675 .run(model.graph) 1676 ) 1677 1678 @skipIfNoFBGEMM 1679 def test_quantized_conv(self): 1680 conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 1681 1682 class Conv(torch.nn.Module): 1683 def __init__(self, dim): 1684 super().__init__() 1685 self.conv = conv_module[dim](3, 3, 3).float() 1686 1687 def forward(self, x): 1688 return self.conv(x) 1689 1690 options = itertools.product([1, 2, 3], [True, False]) 1691 for dim, tracing in options: 1692 model = self.checkGraphModeOp( 1693 Conv(dim), 1694 self.img_data_dict[dim], 1695 f"quantized::conv{dim}d", 1696 tracing, 1697 ) 1698 # make sure there is only one quantize_per_tensor for input 1699 # and conv2d_prepack is folded 1700 FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( 1701 model.graph 1702 ) 1703 1704 
FileCheck().check_not(f"quantized::conv{dim}d_prepack").run(model.graph) 1705 1706 @skipIfNoFBGEMM 1707 def test_quantized_conv_relu(self): 1708 """tests for conv1d_relu/conv2d_relu/conv3d_relu""" 1709 conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 1710 1711 class ConvNdRelu(torch.nn.Module): 1712 def __init__(self, dim, inplace): 1713 super().__init__() 1714 self.conv = conv_module[dim](3, 3, 3).float() 1715 self.relu = torch.nn.ReLU(inplace) 1716 1717 def forward(self, x): 1718 return self.relu(self.conv(x)) 1719 1720 class ConvNdFunctionalRelu(torch.nn.Module): 1721 def __init__(self, dim): 1722 super().__init__() 1723 self.conv = conv_module[dim](3, 3, 3).float() 1724 1725 def forward(self, x): 1726 return F.relu(self.conv(x)) 1727 1728 class ConvNdInplaceFunctionalRelu(torch.nn.Module): 1729 def __init__(self, dim): 1730 super().__init__() 1731 self.conv = conv_module[dim](3, 3, 3).float() 1732 1733 def forward(self, x): 1734 return F.relu(self.conv(x), True) 1735 1736 options = itertools.product([1, 2, 3], [True, False]) 1737 for dim, tracing in options: 1738 for orig_m in [ 1739 ConvNdRelu(dim, True), 1740 ConvNdRelu(dim, False), 1741 ConvNdFunctionalRelu(dim), 1742 ConvNdInplaceFunctionalRelu(dim), 1743 ]: 1744 conv_name = f"conv{dim}d" 1745 m = self.checkGraphModeOp( 1746 orig_m, 1747 self.img_data_dict[dim], 1748 f"quantized::conv{dim}d_relu(", 1749 tracing=tracing, 1750 ) 1751 1752 FileCheck().check_not(f"aten::conv{dim}d(").check_not( 1753 "aten::relu" 1754 ).check_not(f"quantized::conv{dim}d(").check_not( 1755 "quantized::relu(" 1756 ).run( 1757 m.graph 1758 ) 1759 1760 @skipIfNoFBGEMM 1761 def test_quantized_add_alpha(self): 1762 """Test quant fusion for multiple aten::add using same 1763 constant alpha as the third argument 1764 """ 1765 1766 class QuantizedAdd(torch.nn.Module): 1767 def __init__(self) -> None: 1768 super().__init__() 1769 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1770 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1771 1772 def forward(self, x, y): 1773 x = self.conv1(x) 1774 y = self.conv2(y) 1775 z = x + y 1776 w = y + z 1777 return z + w 1778 1779 data = [ 1780 [ 1781 torch.randn(1, 2, 5, 5, dtype=torch.float), 1782 torch.randn(1, 2, 5, 5, dtype=torch.float), 1783 ] 1784 ] 1785 for tracing in [True, False]: 1786 m = self.checkGraphModeOp(QuantizedAdd(), data, "quantized::add", tracing) 1787 FileCheck().check_count("quantized::add", 3, exactly=True).run(m.graph) 1788 FileCheck().check_not("aten::add").check_not("aten::add_").run(m.graph) 1789 1790 @skipIfNoFBGEMM 1791 def test_quantized_add_relu_alpha(self): 1792 """Test quant fusion for multiple aten::add using same 1793 constant alpha as the third argument in add_relu pattern 1794 """ 1795 1796 class AddRelu(torch.nn.Module): 1797 def __init__(self, inplace): 1798 super().__init__() 1799 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1800 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1801 self.relu = torch.nn.ReLU(inplace) 1802 1803 def forward(self, x, y): 1804 x = self.conv1(x) 1805 y = self.conv2(y) 1806 x = x + y 1807 x = self.relu(x) 1808 x = x + y 1809 return self.relu(x) 1810 1811 class InplaceAddRelu(torch.nn.Module): 1812 def __init__(self, inplace): 1813 super().__init__() 1814 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1815 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1816 self.relu = torch.nn.ReLU(inplace) 1817 1818 def forward(self, x, y): 1819 x = self.conv1(x) 1820 y = self.conv2(y) 1821 x += y 1822 x = self.relu(x) 1823 x += y 1824 return self.relu(x) 
1825 1826 class AddFunctionalRelu(torch.nn.Module): 1827 def __init__(self) -> None: 1828 super().__init__() 1829 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1830 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1831 1832 def forward(self, x, y): 1833 x = self.conv1(x) 1834 y = self.conv2(y) 1835 x = x + y 1836 x = F.relu(x) 1837 x = x + y 1838 return F.relu(x) 1839 1840 class InplaceAddFunctionalRelu(torch.nn.Module): 1841 def __init__(self) -> None: 1842 super().__init__() 1843 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1844 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1845 1846 def forward(self, x, y): 1847 x = self.conv1(x) 1848 y = self.conv2(y) 1849 x += y 1850 x = F.relu(x) 1851 x += y 1852 return F.relu(x) 1853 1854 class AddInplaceFunctionalRelu(torch.nn.Module): 1855 def __init__(self) -> None: 1856 super().__init__() 1857 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1858 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1859 1860 def forward(self, x, y): 1861 x = self.conv1(x) 1862 y = self.conv2(y) 1863 x = x + y 1864 x = F.relu(x, True) 1865 x = x + y 1866 return F.relu(x, True) 1867 1868 class InplaceAddInplaceFunctionalRelu(torch.nn.Module): 1869 def __init__(self) -> None: 1870 super().__init__() 1871 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1872 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1873 1874 def forward(self, x, y): 1875 x = self.conv1(x) 1876 y = self.conv2(y) 1877 x += y 1878 x = F.relu(x, True) 1879 x += y 1880 return F.relu(x, True) 1881 1882 data = [ 1883 [ 1884 torch.rand((1, 2, 5, 5), dtype=torch.float), 1885 torch.rand((1, 2, 5, 5), dtype=torch.float), 1886 ] 1887 ] 1888 for m_orig in [ 1889 AddRelu(True), 1890 AddRelu(False), 1891 InplaceAddRelu(True), 1892 InplaceAddRelu(False), 1893 AddFunctionalRelu(), 1894 InplaceAddFunctionalRelu(), 1895 AddInplaceFunctionalRelu(), 1896 InplaceAddInplaceFunctionalRelu(), 1897 ]: 1898 for tracing in [True, False]: 1899 m = self.checkGraphModeOp( 1900 m_orig, data, "quantized::add_relu(", tracing=tracing 1901 ) 1902 FileCheck().check_count("quantized::add_relu(", 2, exactly=True).run( 1903 m.graph 1904 ) 1905 FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not( 1906 "aten::relu(" 1907 ).check_not("aten::relu_(").check_not("quantized::add(").check_not( 1908 "quantized::relu(" 1909 ).run( 1910 m.graph 1911 ) 1912 1913 @skipIfNoFBGEMM 1914 def test_quantized_add(self): 1915 class QuantizedAdd(torch.nn.Module): 1916 def __init__(self) -> None: 1917 super().__init__() 1918 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1919 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1920 1921 def forward(self, x, y): 1922 x = self.conv1(x) 1923 y = self.conv2(y) 1924 return x + y 1925 1926 class QuantizedInplaceAdd(torch.nn.Module): 1927 def __init__(self) -> None: 1928 super().__init__() 1929 self.conv1 = torch.nn.Conv2d(2, 2, 2).float() 1930 self.conv2 = torch.nn.Conv2d(2, 2, 2).float() 1931 1932 def forward(self, x, y): 1933 x = self.conv1(x) 1934 y = self.conv2(y) 1935 x += y 1936 return x 1937 1938 class NonQuantizedAdd(torch.nn.Module): 1939 def forward(self, x, y): 1940 return x + y 1941 1942 class NonQuantizedInplaceAdd(torch.nn.Module): 1943 def forward(self, x, y): 1944 x += y 1945 return x 1946 1947 data = [ 1948 [ 1949 torch.randn(1, 2, 3, 3, dtype=torch.float), 1950 torch.randn(1, 2, 3, 3, dtype=torch.float), 1951 ] 1952 ] 1953 for m, quantized in [ 1954 (QuantizedAdd(), True), 1955 (QuantizedInplaceAdd(), True), 1956 (NonQuantizedAdd(), False), 1957 (NonQuantizedInplaceAdd(), False), 1958 ]: 1959 for 
            for tracing in [True, False]:
                op = "quantized::add" if quantized else "aten::add"
                m = self.checkGraphModeOp(m, data, op, tracing)
                # TODO: remove after refactor of checkGraphModeOp
                if quantized:
                    FileCheck().check_not("aten::add").check_not("aten::add_").run(
                        m.graph
                    )
                else:
                    FileCheck().check_not("quantized::add").run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_add_scalar(self):
        class QuantizedAddScalar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return x + 3

        class QuantizedInplaceAddScalar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x += 3
                return x

        class NonQuantizedAddScalar(torch.nn.Module):
            def forward(self, x):
                return x + 3

        class NonQuantizedInplaceAddScalar(torch.nn.Module):
            def forward(self, x):
                x += 3
                return x

        data = [[torch.randn(1, 2, 3, 3, dtype=torch.float)]]
        for m, quantized in [
            (QuantizedAddScalar(), True),
            (QuantizedInplaceAddScalar(), True),
            (NonQuantizedAddScalar(), False),
            (NonQuantizedInplaceAddScalar(), False),
        ]:
            for tracing in [True, False]:
                op = "quantized::add_scalar" if quantized else "aten::add"
                # we don't check the numerical consistency for add_scalar
                # since it's not supported
                m = self.checkGraphModeOp(m, data, op, tracing, check=False)
                # TODO: remove after refactor of checkGraphModeOp
                if quantized:
                    FileCheck().check_not("aten::add").check_not("aten::add_").run(
                        m.graph
                    )
                else:
                    FileCheck().check_not("quantized::add_scalar").run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_add_relu(self):
        class AddRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x + y
                return self.relu(x)

        class InplaceAddRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x += y
                return self.relu(x)

        class AddFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x + y
                return F.relu(x)

        class InplaceAddFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x += y
                return F.relu(x)

        class AddInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x + y
                return F.relu(x, True)

        class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x += y
                return F.relu(x, True)

        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float),
                 torch.rand((1, 2, 5, 5), dtype=torch.float)]]
        for m in [
            AddRelu(True),
            AddRelu(False),
            InplaceAddRelu(True),
            InplaceAddRelu(False),
            AddFunctionalRelu(),
            InplaceAddFunctionalRelu(),
            AddInplaceFunctionalRelu(),
            InplaceAddInplaceFunctionalRelu(),
        ]:
            for tracing in [True, False]:
                m = self.checkGraphModeOp(m, data, "quantized::add_relu(", tracing)
                FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
                    "aten::relu("
                ).check_not("aten::relu_(").check_not("quantized::add(").check_not(
                    "quantized::relu("
                ).run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_add_scalar_relu(self):
        class AddScalarRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                x = self.conv(x)
                return self.relu(x + 3)

        class InplaceAddScalarRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                x = self.conv(x)
                x += 3
                return self.relu(x)

        class AddScalarFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return F.relu(x + 3)

        class InplaceAddScalarFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x += 3
                return F.relu(x)

        class AddScalarInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return F.relu(x + 3, True)

        class InplaceAddScalarInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x += 3
                return F.relu(x, True)

        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
        for m in [
            AddScalarRelu(True),
            AddScalarRelu(False),
            InplaceAddScalarRelu(True),
            InplaceAddScalarRelu(False),
            AddScalarFunctionalRelu(),
            InplaceAddScalarFunctionalRelu(),
            AddScalarInplaceFunctionalRelu(),
            InplaceAddScalarInplaceFunctionalRelu(),
        ]:
            for tracing in [True, False]:
                # quantized::add_scalar_relu or quantized::add_scalar_relu_out
                # TODO: split this after refactor of checkGraphModeOp
                m = self.checkGraphModeOp(
                    m, data, "quantized::add_scalar_relu", tracing, check=False
                )
                FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
                    "aten::relu("
                ).check_not("aten::relu_(").check_not(
                    "quantized::add_scalar("
                ).check_not("quantized::relu(").run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_cat(self):
        """Quantization of the output of cat depends on the inputs of cat:
        we only quantize the output of cat when its inputs are quantized.
        """

        class QuantizedCat(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return torch.cat([x, y], 1)

        class NonQuantizedCat(torch.nn.Module):
            def forward(self, x, y):
                return torch.cat([x, y], 1)

        data = [[torch.randn(1, 2, 5, 5, dtype=torch.float),
                 torch.randn(1, 2, 5, 5, dtype=torch.float)]]
        for tracing in [True, False]:
            m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing)
            FileCheck().check_not("aten::cat").run(m.graph)

            m = self.checkGraphModeOp(NonQuantizedCat(), data, "aten::cat", tracing)
            FileCheck().check_not("quantized::cat").run(m.graph)

    @skipIfNoFBGEMM
    def test_qbatch_norm(self):
        bn_module = {
            1: torch.nn.BatchNorm1d,
            2: torch.nn.BatchNorm2d,
            3: torch.nn.BatchNorm3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return self.bn(x)

        options = itertools.product([True, False], [1, 2, 3])
        for tracing, dim in options:
            model = self.checkGraphModeOp(
                M(dim), self.img_data_dict[dim], "quantized::batch_norm", tracing
            )

            FileCheck().check_not("aten::batch_norm").run(model.graph)

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu_BNRelu(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super().__init__()
                self.bn = bn_module[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                return self.relu(self.bn(x))

        options = itertools.product([True, False], [2, 3])
        for tracing, dim in options:
            for instance in [BNRelu(dim, True), BNRelu(dim, False)]:
                model = self.checkGraphModeOp(
                    instance,
                    self.img_data_dict[dim],
                    "quantized::batch_norm_relu",
                    tracing,
                )
                FileCheck().check_not("aten::batch_norm").check_not(
                    "aten::relu"
                ).check_not("aten::relu_").run(model.graph)

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu_BNFuncRelu(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), False)

        options = itertools.product([True, False], [2, 3])
        for tracing, dim in options:
            instance = BNFuncRelu(dim)
            model = self.checkGraphModeOp(
                instance, self.img_data_dict[dim], "quantized::batch_norm_relu", tracing
            )
            FileCheck().check_not("aten::batch_norm").check_not("aten::relu").check_not(
                "aten::relu_"
            ).run(model.graph)
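
    # Note: the variant below uses the functional relu with inplace=True; the
    # same fused quantized::batch_norm_relu pattern is still expected, the only
    # difference from the test above is how the relu shows up in the graph.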
    @skipIfNoFBGEMM
    def test_qbatch_norm_relu_BNFuncInplaceRelu(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), True)

        options = itertools.product([True, False], [2, 3])
        for tracing, dim in options:
            instance = BNFuncInplaceRelu(dim)
            model = self.checkGraphModeOp(
                instance, self.img_data_dict[dim], "quantized::batch_norm_relu", tracing
            )
            FileCheck().check_not("aten::batch_norm").check_not("aten::relu").check_not(
                "aten::relu_"
            ).run(model.graph)

    @skipIfNoFBGEMM
    def test_quantized_mul(self):
        class QuantizedMul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return x * y

        class QuantizedInplaceMul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x *= y
                return x

        class NonQuantizedMul(torch.nn.Module):
            def forward(self, x, y):
                return x * y

        class NonQuantizedInplaceMul(torch.nn.Module):
            def forward(self, x, y):
                x *= y
                return x

        data = [[torch.randn(1, 2, 10, 10, dtype=torch.float),
                 torch.randn(1, 2, 10, 10, dtype=torch.float)]]
        for m, quantized in [
            (QuantizedMul(), True),
            (QuantizedInplaceMul(), True),
            (NonQuantizedMul(), False),
            (NonQuantizedInplaceMul(), False),
        ]:
            for tracing in [True, False]:
                op = "quantized::mul" if quantized else "aten::mul"
                m = self.checkGraphModeOp(m, data, op, tracing)
                # TODO: remove after refactor of checkGraphModeOp
                if quantized:
                    FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
                        m.graph
                    )
                else:
                    FileCheck().check_not("quantized::mul").run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_mul_scalar(self):
        class QuantizedMulScalar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return x * 3

        class QuantizedInplaceMulScalar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x *= 3
                return x

        class NonQuantizedMulScalar(torch.nn.Module):
            def forward(self, x):
                return x * 3

        class NonQuantizedInplaceMulScalar(torch.nn.Module):
            def forward(self, x):
                x *= 3
                return x

        data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
        for m, quantized in [
            (QuantizedMulScalar(), True),
            (QuantizedInplaceMulScalar(), True),
            (NonQuantizedMulScalar(), False),
            (NonQuantizedInplaceMulScalar(), False),
        ]:
            for tracing in [True, False]:
                op = "quantized::mul_scalar" if quantized else "aten::mul"
                # we don't check the numerical consistency for mul_scalar
                # since it's not supported
                m = self.checkGraphModeOp(m, data, op, tracing, check=False)
                # TODO: remove after refactor of checkGraphModeOp
                if quantized:
                    FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
                        m.graph
                    )
                else:
                    FileCheck().check_not("quantized::mul_scalar").run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_mul_relu(self):
        class MulRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x * y
                return self.relu(x)

        class InplaceMulRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x *= y
                return self.relu(x)

        class MulFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x * y
                return F.relu(x)

        class InplaceMulFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x *= y
                return F.relu(x)

        class MulInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x = x * y
                return F.relu(x, True)

        class InplaceMulInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                x *= y
                return F.relu(x, True)

        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float),
                 torch.rand((1, 2, 5, 5), dtype=torch.float)]]
        for m in [
            MulRelu(True),
            MulRelu(False),
            InplaceMulRelu(True),
            InplaceMulRelu(False),
            MulFunctionalRelu(),
            InplaceMulFunctionalRelu(),
            MulInplaceFunctionalRelu(),
            InplaceMulInplaceFunctionalRelu(),
        ]:
            for tracing in [True, False]:
                m = self.checkGraphModeOp(m, data, "quantized::mul_relu(", tracing)
                FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
                    "aten::relu("
                ).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
                    "quantized::relu("
                ).run(m.graph)

    @skipIfNoFBGEMM
    def test_quantized_mul_scalar_relu(self):
        class MulScalarRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                x = self.conv(x)
                return self.relu(x * 3)

        class InplaceMulScalarRelu(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                x = self.conv(x)
                x *= 3
                return self.relu(x)

        class MulScalarFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return F.relu(x * 3)

        class InplaceMulScalarFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x *= 3
                return F.relu(x)

        class MulScalarInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                return F.relu(x * 3, True)

        class InplaceMulScalarInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x):
                x = self.conv(x)
                x *= 3
                return F.relu(x, True)

        data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
        for m in [
            MulScalarRelu(True),
            MulScalarRelu(False),
            InplaceMulScalarRelu(True),
            InplaceMulScalarRelu(False),
            MulScalarFunctionalRelu(),
            InplaceMulScalarFunctionalRelu(),
            MulScalarInplaceFunctionalRelu(),
            InplaceMulScalarInplaceFunctionalRelu(),
        ]:
            for tracing in [True, False]:
                # quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
                m = self.checkGraphModeOp(
                    m, data, "quantized::mul_scalar_relu", tracing, check=False
                )
                FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
                    "aten::relu("
                ).check_not("aten::relu_(").check_not(
                    "quantized::mul_scalar("
                ).check_not("quantized::relu(").run(m.graph)

    @override_qengines
    def test_hardswish(self):
        class FunctionalHardswish(torch.nn.Module):
            def __init__(self, inplace):
                super().__init__()
                self.inplace = inplace

            def forward(self, input):
                return torch.nn.functional.hardswish(input, inplace=self.inplace)

        modules = [
            torch.nn.Hardswish(),
            FunctionalHardswish(True),
            FunctionalHardswish(False),
        ]

        for test_case in itertools.product([True, False], modules):
            tracing, m = test_case
            m = self.checkGraphModeOp(
                m, self.img_data_2d, "quantized::hardswish", tracing
            )
            FileCheck().check_not("aten::hardswish").check_not("aten::hardswish_").run(
                m.graph
            )

    @override_qengines
    def test_elu(self):
        class FunctionalELU(torch.nn.Module):
            def __init__(self, inplace=False):
                super().__init__()
                self.inplace = inplace

            def forward(self, input):
                return torch.nn.functional.elu(input, inplace=self.inplace)

        modules = [torch.nn.ELU, FunctionalELU]
        for test_case in itertools.product([True, False], [True, False], modules):
            tracing, inplace, mod_class = test_case
            m = mod_class(inplace=inplace)
            m = self.checkGraphModeOp(m, self.img_data_2d, "quantized::elu", tracing)
            FileCheck().check_not("aten::elu").check_not("aten::elu_").run(m.graph)
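
    # The normalization tests below use @override_qengines, so each of them is
    # run under every available quantized engine rather than under FBGEMM only.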
    @override_qengines
    def test_layer_norm(self):
        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)] for _ in range(2)]
        layer_norm = torch.nn.LayerNorm([2, 5, 5])
        for tracing in [True, False]:
            m = self.checkGraphModeOp(
                layer_norm, data, "quantized::layer_norm", tracing
            )
            FileCheck().check_not("aten::layer_norm").run(m.graph)

    @override_qengines
    def test_group_norm(self):
        data = [[torch.rand((1, 4, 5, 5), dtype=torch.float)] for _ in range(2)]
        group_norm = torch.nn.GroupNorm(2, 4)
        for tracing in [True, False]:
            m = self.checkGraphModeOp(
                group_norm, data, "quantized::group_norm", tracing
            )
            FileCheck().check_not("aten::group_norm").run(m.graph)

    @override_qengines
    def test_instance_norm(self):
        data_1d = [[torch.rand((1, 4, 5), dtype=torch.float)] for _ in range(2)]
        data_2d = [[torch.rand((1, 4, 5, 1), dtype=torch.float)] for _ in range(2)]
        data_3d = [[torch.rand((1, 4, 5, 1, 1), dtype=torch.float)] for _ in range(2)]
        data = {1: data_1d, 2: data_2d, 3: data_3d}
        instance_norm_modules = {
            1: torch.nn.InstanceNorm1d,
            2: torch.nn.InstanceNorm2d,
            3: torch.nn.InstanceNorm3d,
        }

        options = itertools.product([1, 2, 3], [True, False])
        for dim, tracing in options:
            instance_norm = instance_norm_modules[dim](4)
            m = self.checkGraphModeOp(
                instance_norm, data[dim], "quantized::instance_norm", tracing
            )
            FileCheck().check_not("aten::instance_norm").run(m.graph)

    @skipIfNoFBGEMM
    def test_dequantize_tuple(self):
        """Make sure dequantize can support a Tuple of tensors"""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
                self.conv2 = torch.nn.Conv2d(3, 3, 3).float()

            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                x1 = self.conv1(x)
                x2 = self.conv2(x)
                return x1, x2

        for tracing in [True, False]:
            self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing)

    @skipIfNoFBGEMM
    def test_clamp(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.relu6(x)
                self.relu6_(x)
                x = F.relu6(x)
                x = torch.clamp(x, -3, 3)
                x = x.clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                x = self.hardtanh(x)
                self.hardtanh_(x)
                x = F.hardtanh(x)
                F.hardtanh_(x)
                return x

        data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
        options = itertools.product(
            ["aten::clamp", "aten::hardtanh", "aten::hardtanh_"], [True, False]
        )
        for op, tracing in options:
            m = self.checkGraphModeOp(M(), data, op, tracing)
            FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
                m.graph
            )

            FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)

    def test_general_shape_ops(self):
        """A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
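        # Roughly speaking, the pass is expected to move dequantize past
        # shape-only ops (e.g. dequant -> flatten becomes flatten -> dequant),
        # so neighboring quantized ops can stay in the quantized domain; see
        # the FileCheck assertions at the end of this test.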
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward, therefore we only test scripting
        m = torch.jit.script(M())
        qconfig = script_qconfig(default_qconfig)
        # dummy data to suppress warning
        get_forward(qconfig.activation)(data)
        get_forward(qconfig.weight)(data)

        m = wrap_cpp_module(
            torch._C._jit_pass_insert_observers(
                m._c, "forward", {"": qconfig}, inplace=False
            )
        )
        m = convert_jit(m)
        # This checks that the dequantize from the output of the first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fuse the two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
            m.graph
        )

        FileCheck().check_count("quantized::conv2d(", 2, exactly=True).run(m.graph)

        FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)

        FileCheck().check("quantized::add_scalar").check("quantized::mul_scalar").run(
            m.graph
        )

    def test_general_value_ops(self):
        """A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d
        without actually checking for execution of these ops
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d(1)
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                # interpolate node will introduce 3 quantize_per_tensor ops
                x = F.interpolate(x, 4, mode="nearest")  # interpolate node
                x = F.upsample(x, (32, 32))  # interpolate node
                x = F.upsample_nearest(x, (32, 32))  # interpolate node
                x = F.interpolate(x, 4, mode="linear")  # common node
                x = F.upsample_bilinear(x, (32, 32))  # common node
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward, therefore we only test scripting
        m = torch.jit.script(M())
        qconfig = script_qconfig(default_qconfig)
        # dummy data to suppress warning
        data = torch.rand(1, 3, 10, 10)
        get_forward(qconfig.activation)(data)
        get_forward(qconfig.weight)(data)

        m = wrap_cpp_module(
            torch._C._jit_pass_insert_observers(
                m._c, "forward", {"": qconfig}, inplace=False
            )
        )
        # Check that the model before finalize contains unfused patterns
        # that numerically match the model after quantization, by counting
        # aten::quantize_per_tensor calls: conv has 3 quantize_per_tensor for
        # activations and 1 for the weight, and for N general value ops
        # between convs we should have N + 1 quantize_per_tensor between
        # these ops
        m1 = convert_jit(m, debug=True)
        # NB: this needs to be updated when we add more ops to the test
        # mapping from the number of quantize_per_tensor calls an op needs
        # to the number of such ops; for example, the key `3` means each op
        # of that kind introduces 3 quantize_per_tensor calls
        num_op_by_num_quant = {1: 32, 2: 2, 3: 3}
        num_quantize_per_tensor = 1  # for output
        for num_quant, num_op in num_op_by_num_quant.items():
            num_quantize_per_tensor += num_op * num_quant
        num_quantize_per_tensor -= 4  # constant propagation removes some prepacks
        FileCheck().check_count(
            "aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True
        ).run(m1.graph)

        # This checks that the dequantize from the output of the first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fuse the two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        m2 = convert_jit(m, debug=False)
        FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True).run(
            m2.graph
        )
        FileCheck().check_count("quantized::conv2d(", 2, exactly=True).check(
            "aten::dequantize("
        ).run(m2.graph)

    @override_qengines
    def test_conv_with_benchmark_flag(self):
        r"""Verifies that convolutions get quantized when
        torch.backends.cudnn.benchmark is enabled
        """
        if not qengine_is_qnnpack():
            return
        with torch.backends.cudnn.flags(enabled=True):
            m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
            m.eval()
            m = torch.jit.trace(m, torch.rand(4, 1, 4, 4))
            qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
            prepared_model = torch.ao.quantization.prepare_jit(m, {"": qconfig})
            prepared_model(torch.rand(4, 1, 4, 4))
            converted_model = torch.ao.quantization.convert_jit(prepared_model)
            FileCheck().check("quantized::conv2d").run(converted_model.graph)

    @skipIfNoFBGEMM
    def test_cat_linear(self):
        class LinearModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(5, 5)

            def forward(self, x, y):
                a = torch.cat([x, y])
                b = F.linear(a, self.weight)
                c = F.linear(b, self.weight)
                return b, c

        model = LinearModel().eval()
        qconfig = {"": default_qconfig}
        float_model = torch.jit.script(model)
        prepared_model = prepare_jit(float_model, qconfig)
        prepared_model(torch.rand(5, 5), torch.rand(5, 5))
        converted_model = convert_jit(prepared_model)
        FileCheck().check("quantized::linear").check("quantized::linear").run(
            converted_model.graph
        )


class TestQuantizeDynamicJitPasses(QuantizationTestCase):
    def test_prepare_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        model = torch.jit.script(M())
        for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
            m = prepare_dynamic_jit(model, {"": qconfig})

            # observer for weight
            assert len(attrs_with_prefix(m.fc, "_observer_")) == 1

            if qconfig == float16_dynamic_qconfig:
                observer_name = 'PlaceholderObserver = prim::GetAttr[name="_observer_'
                FileCheck().check(observer_name).run(m.fc.graph)
            else:
                # for input of FC for dynamic quant
                assert len(attrs_with_prefix(m, "_observer_")) == 1
                observer_name = 'Observer = prim::GetAttr[name="_observer_'
                FileCheck().check(observer_name).check(
                    'prim::GetAttr[name="fc"]'
                ).check("prim::CallMethod").check_not(observer_name).run(m.graph)

    def test_prepare_dynamic_child_qconfig(self):
        class Sub(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
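        # A qconfig_dict key names a submodule path ("sub.fc" here); submodules
        # not covered by any key are left unquantized, while an empty-string
        # key would apply the qconfig to the whole model.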
        # only quantize child module.
        m = prepare_dynamic_jit(m, {"sub.fc": default_dynamic_qconfig})

        # input of sub for dynamic quant
        assert len(attrs_with_prefix(m, "_observer_")) == 1
        # not quantized
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
        # no observers since we observe in the outermost call site
        assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
        # weight of linear
        assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1
        FileCheck().check('prim::GetAttr[name="sub').check("prim::CallMethod").check(
            'Observer = prim::GetAttr[name="_observer_'
        ).check("prim::CallMethod").check_not(
            'Observer = prim::GetAttr[name="_observer_'
        ).run(m.graph)

    def test_insert_quant_dequant_linear_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 5).float()
                self.fc2 = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                x = self.fc1(x)
                return self.fc2(x)

        for is_per_channel in [True, False]:
            m = torch.jit.script(M())
            qconfig = (
                per_channel_dynamic_qconfig
                if is_per_channel is True
                else default_dynamic_qconfig
            )
            m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
            assert (
                len(m._modules._c.items()) == 2
            ), "Expected to have two Linear submodules"

            wt_quant_func = (
                "aten::quantize_per_channel"
                if is_per_channel
                else "aten::quantize_per_tensor"
            )
            act_quant_func = "aten::quantize_per_tensor"
            # quantizing activations
            FileCheck().check("aten::_choose_qparams_per_tensor").check_next(
                act_quant_func
            ).check_next("aten::dequantize").check(
                "aten::_choose_qparams_per_tensor"
            ).check_next(act_quant_func).check_next("aten::dequantize").check(
                wt_quant_func
            ).check_next("aten::dequantize").check_not(wt_quant_func).check(
                "return"
            ).run(m.graph)

    @override_qengines
    def test_dynamic_multi_op(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

            def forward(self, x):
                x = x + 5
                return self.fc1(x)

        x = torch.randn(5, 5)
        for tracing in [True, False]:
            model = self.checkGraphModeOp(
                M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
            )
            # add op is not dynamically quantized.
            FileCheck().check("aten::add").run(model.graph)

    @override_qengines
    def test_dynamic_quant_multi_uses(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                size1 = x.size()
                size2 = x.size()
                return self.fc(x), size1, size2

        x = torch.randn(5, 5)
        for tracing in [True, False]:
            model = self.checkGraphModeOp(
                M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
            )
            FileCheck().check_not("aten::_choose_qparams_per_tensor").run(model.graph)

    @override_qengines
    def test_dynamic_shared_weights(self):
        class myMod(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.linear = nn.Linear(5, 5)
                self.linear.weight = weight

            def forward(self, x):
                return self.linear(x)

        class DynamicModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(5, 5))
                self.mod1 = myMod(self.weight)

            def forward(self, x):
                y = self.mod1(x)
                z = torch.nn.functional.linear(y, self.weight)
                return z

        model = torch.jit.script(DynamicModel()).eval()
        data = torch.randn(5, 5, dtype=torch.float)
        quant_ops = ["mod1", ""]
        counts = [1, 2]
        for op, count in zip(quant_ops, counts):
            qconfig_dict = {op: default_dynamic_qconfig}
            m1 = quantize_dynamic_jit(model, qconfig_dict)
            out_graph = m1(data)

            FileCheck().check_count(
                "quantized::linear_dynamic(", count, exactly=True
            ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)

            # Explicitly call forward on model before convert
            m2 = prepare_dynamic_jit(model, qconfig_dict)
            m2(data)
            m2 = convert_dynamic_jit(m2, debug=False)
            out_ref = m2(data)
            self.assertEqual(out_graph, out_ref)

    @override_qengines
    def test_dynamic_with_if(self):
        class Res(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(5, 5))

            def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
                if cond:
                    return torch.nn.functional.linear(x, self.weight)
                else:
                    return torch.nn.functional.linear(x, self.weight)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.res1 = Res()
                self.res2 = Res()

            def forward(self, x):
                x = self.res1(x, True)
                x = self.res2(x, False)
                return x

        model = torch.jit.script(M()).eval()
        data = torch.randn(5, 5, dtype=torch.float)
        qconfig_dict = {"": default_dynamic_qconfig}
        for tracing in [True, False]:
            m1 = self.checkGraphModeOp(
                M(), data, "quantized::linear_dynamic", tracing=tracing, dynamic=True
            )
            FileCheck().check_count(
                "quantized::linear_dynamic(", 2, exactly=True
            ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)

        # Check to make sure weight observers run correctly
        ref_qparams = []
        qconfig = script_qconfig(default_dynamic_qconfig)
        wt_module = wrap_cpp_module(qconfig.weight)
        for wt in [model.res1.weight, model.res2.weight]:
            wt_module(wt)
            qparams = wt_module.calculate_qparams()
            ref_qparams.append((qparams[0].item(), qparams[1].item()))

        m2 = quantize_dynamic_jit(model, qconfig_dict, debug=True)
        graph_params = []
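        # In debug mode the chosen quantization parameters are exposed as
        # attributes such as "weight.<id>_scale_0" / "weight.<id>_zero_point_0"
        # on each submodule; compare them against the reference observer
        # qparams computed above.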
        for x, obs in m2._modules._c.items():
            if x == "res1":
                graph_params.append(
                    (
                        obs.getattr("weight.2_scale_0"),
                        obs.getattr("weight.2_zero_point_0"),
                    )
                )
            elif x == "res2":
                graph_params.append(
                    (
                        obs.getattr("weight.4_scale_0"),
                        obs.getattr("weight.4_zero_point_0"),
                    )
                )
        self.assertEqual(ref_qparams, graph_params)

    def test_dynamic_weight_observer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5).float()
                self.fc2 = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                x = self.fc(x)
                return self.fc2(x)

        qconfig_dict = {"": default_dynamic_qconfig}
        eager_model = M().eval()
        for tracing in [True, False]:
            x = torch.rand(5, 5)
            model = get_script_module(eager_model, tracing, x)
            ref_qparams = []
            for wt in [model.fc.weight, model.fc2.weight]:
                wt_module = default_dynamic_qconfig.weight()
                wt_module(wt)
                qparams = wt_module.calculate_qparams()
                ref_qparams.append((qparams[0].item(), qparams[1].item()))
            model = quantize_dynamic_jit(model, qconfig_dict, debug=True)
            graph_qparams = []
            for x, obs in model._modules._c.items():
                n = 2 if x == "fc" and tracing else 1
                graph_qparams.append(
                    (
                        obs.getattr(f"weight.{n}_scale_0"),
                        obs.getattr(f"weight.{n}_zero_point_0"),
                    )
                )
            self.assertEqual(ref_qparams, graph_qparams)

    def test_convert_dynamic_fp16(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig}, debug=True)
        FileCheck().check("aten::_saturate_weight_to_fp16").check(
            "aten::linear"
        ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)

    def test_quantize_dynamic_fp16(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig})

        FileCheck().check("quantized::linear_dynamic_fp16").check_not(
            "aten::linear"
        ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)


class TestQuantizeDynamicJitOps(QuantizationTestCase):
    """Test graph mode post training dynamic quantization works
    for individual ops end to end.
    """
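
    # Dynamic quantization quantizes weights ahead of time while activation
    # quantization parameters are chosen at runtime, which is why the checks
    # below look for quantized::linear_dynamic rather than a static
    # quantized::linear.  A minimal usage sketch (for orientation only; the
    # module and input names are placeholders, the APIs are the ones imported
    # at the top of this file):
    #
    #   model = torch.jit.script(MyModule().eval())
    #   model = quantize_dynamic_jit(model, {"": default_dynamic_qconfig})
    #   out = model(example_input)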

    @override_qengines
    def test_linear(self):
        class FunctionalLinear(torch.nn.Module):
            def __init__(self, weight, bias):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        x = torch.rand(5, 5)
        for tracing in [True, False]:
            model = self.checkGraphModeOp(
                torch.nn.Linear(5, 5),
                x,
                "quantized::linear_dynamic",
                tracing=tracing,
                dynamic=True,
            )

        weight = torch.rand(5, 5)
        b = torch.rand(5)
        for tracing, has_bias in itertools.product([True, False], [True, False]):
            bias = b if has_bias else None
            model = self.checkGraphModeOp(
                FunctionalLinear(weight, bias),
                x,
                "quantized::linear_dynamic",
                tracing=tracing,
                dynamic=True,
            )

    @skipIfNoFBGEMM
    def test_embedding_bag(self):
        class M(torch.nn.Module):
            def __init__(self, weights):
                super().__init__()
                self.embedding1 = torch.nn.EmbeddingBag(
                    num_embeddings=10,
                    embedding_dim=12,
                    include_last_offset=True,
                    sparse=True,
                    _weight=weights,
                    mode="sum",
                )

                self.embedding2 = torch.nn.EmbeddingBag(
                    num_embeddings=10,
                    embedding_dim=12,
                    include_last_offset=True,
                    sparse=True,
                    _weight=weights,
                    mode="sum",
                )

            def forward(self, indices1, offsets1, indices2, offsets2):
                e1 = self.embedding1(indices1, offsets1)
                e2 = self.embedding2(indices2, offsets2)
                return e1, e2

        weights = torch.randn(10, 12, dtype=torch.float32)
        module = M(weights)

        indices = torch.tensor(
            [9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
             3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
        )
        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        dummy_inputs = (indices, offsets, indices, offsets)
        for trace in [True, False]:
            if trace:
                m = torch.jit.trace(module, dummy_inputs)
            else:
                m = torch.jit.script(module)
            int4_qconfig = QConfig(
                activation=PlaceholderObserver.with_args(
                    dtype=torch.float, custom_op_name="embedding_bag_4bit"
                ),
                weight=PlaceholderObserver.with_args(
                    custom_op_name="embedding_bag_4bit"
                ),
            )
            int8_qconfig = QConfig(
                activation=PlaceholderObserver.with_args(
                    dtype=torch.float, custom_op_name="embedding_bag_byte"
                ),
                weight=PlaceholderObserver.with_args(
                    custom_op_name="embedding_bag_byte"
                ),
            )
            m = prepare_jit(m, {"embedding1": int4_qconfig, "embedding2": int8_qconfig})
            m = convert_jit(m)
            FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets").check(
                "quantized::embedding_bag_byte_rowwise_offsets"
            ).run(m.graph)
            m(*dummy_inputs)

    # Ensure that attempting to quantize an EmbeddingBag throws an error if
    # padding_idx is not None
    @skipIfNoFBGEMM
    def test_embedding_bag_padding_idx_error(self):
        class M(torch.nn.Module):
            def __init__(self, weights):
                super().__init__()
                self.embedding = torch.nn.EmbeddingBag(
                    num_embeddings=10,
                    embedding_dim=12,
                    include_last_offset=True,
                    sparse=True,
                    _weight=weights,
                    mode="sum",
                    padding_idx=0,
                )

            def forward(self, indices, offsets):
                e = self.embedding(indices, offsets)
                return e

        weights = torch.randn(10, 12, dtype=torch.float32)
        module = M(weights)

        indices = torch.tensor([0, 1, 2, 3, 4])
        offsets = torch.tensor([0, 2, 5])
        dummy_inputs = (indices, offsets)

        int4_qconfig = QConfig(
            activation=PlaceholderObserver.with_args(
                dtype=torch.float, custom_op_name="embedding_bag_4bit"
            ),
            weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_4bit"),
        )
        int8_qconfig = QConfig(
            activation=PlaceholderObserver.with_args(
                dtype=torch.float, custom_op_name="embedding_bag_byte"
            ),
            weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"),
        )

        error_msg = r"Expected aten::embedding_bag padding_idx input to be None"
        for trace, qconfig in itertools.product(
            [True, False], [int4_qconfig, int8_qconfig]
        ):
            if trace:
                m = torch.jit.trace(module, dummy_inputs)
            else:
                m = torch.jit.script(module)
            m = prepare_jit(m, {"embedding": qconfig})
            with self.assertRaisesRegex(RuntimeError, error_msg):
                m = convert_jit(m)


class TestQuantizeJit(QuantizationTestCase):
    @override_qengines
    def test_single_linear(self):
        r"""Compare the result of quantizing a single linear layer in
        eager mode and graph mode
        """
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel(
            torch.backends.quantized.engine
        ).eval()
        linear_model = SingleLayerLinearModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        linear_model.fc1.weight = torch.nn.Parameter(
            annotated_linear_model.fc1.module.weight.detach()
        )
        linear_model.fc1.bias = torch.nn.Parameter(
            annotated_linear_model.fc1.module.bias.detach()
        )
        model_eager = quantize(
            annotated_linear_model, test_only_eval_fn, [self.calib_data]
        )

        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
        model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
        model_script = torch.jit.script(linear_model)
        result_eager = model_eager(self.calib_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_jit(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.calib_data],
                inplace=False,
            )
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    @skipIfNoFBGEMM
    def test_observer_with_ignored_function(self):
        r"""Test observers with ignored functions and make sure they work in
        graph mode
        """
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel("fbgemm").eval()
        for qconfig in [
            QConfig(activation=default_observer, weight=default_weight_observer),
            QConfig(
                activation=default_histogram_observer, weight=default_weight_observer
            ),
            QConfig(
                activation=default_observer, weight=default_per_channel_weight_observer
            ),
        ]:
            annotated_linear_model.qconfig = qconfig
            linear_model = SingleLayerLinearModel().eval()
            # copy the weight from eager mode so that we can
            # compare the result of the two quantized models later
            linear_model.fc1.weight = torch.nn.Parameter(
                annotated_linear_model.fc1.module.weight.detach()
            )
            linear_model.fc1.bias = torch.nn.Parameter(
                annotated_linear_model.fc1.module.bias.detach()
            )
            model_eager = quantize(
                annotated_linear_model, test_only_eval_fn, [self.calib_data]
            )

            qconfig_dict = {"": qconfig}
            model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
            model_script = torch.jit.script(linear_model)
            result_eager = model_eager(self.calib_data[0][0])
            for model_under_test in [model_traced, model_script]:
                model_quantized = quantize_jit(
                    model_under_test,
                    qconfig_dict,
                    test_only_eval_fn,
                    [self.calib_data],
                    inplace=False,
                )
                self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    @override_qengines
    def test_conv(self):
        r"""Compare the result of quantizing a conv layer in
        eager mode and graph mode
        """
        # eager mode
        annotated_conv_model = AnnotatedConvModel(
            torch.backends.quantized.engine
        ).eval()
        conv_model = ConvModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        conv_model.conv.weight = torch.nn.Parameter(
            annotated_conv_model.conv.weight.detach()
        )
        model_eager = quantize(
            annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
        )
        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
        model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
        model_script = torch.jit.script(conv_model)
        result_eager = model_eager(self.img_data_2d[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_jit(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.img_data_2d],
                inplace=False,
            )
            self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)

    @override_qengines
    def test_conv_transpose(self):
        r"""Compare the result of quantizing a conv_transpose layer in
        eager mode and graph mode
        """
        if not qengine_is_qnnpack():
            return  # Currently only qnnpack is supported
        # eager mode
        annotated_conv_model = AnnotatedConvTransposeModel(
            torch.backends.quantized.engine
        ).eval()
        conv_model = ConvTransposeModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        conv_model.conv.weight = torch.nn.Parameter(
            annotated_conv_model.conv.weight.detach()
        )
        model_eager = quantize(
            annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
        )
        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
        model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
        model_script = torch.jit.script(conv_model)
        result_eager = model_eager(self.img_data_2d[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_jit(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.img_data_2d],
                inplace=False,
            )
            self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)

    @override_qengines
    def test_conv_bn(self):
        r"""Compare the result of quantizing conv + bn layer in
        eager mode and graph mode
        """
        # eager mode
        conv_model = AnnotatedConvBnModel().eval()
        conv_model_to_script = ConvBnModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        conv_model_to_script.conv.weight = torch.nn.Parameter(
            conv_model.conv.weight.detach()
        )
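        # Eager mode needs an explicit Conv+BN fusion before quantization,
        # while graph mode (quantize_jit) is expected to fold the batch norm
        # into the conv automatically during its fusion pass.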
        fuse_modules(conv_model, ["conv", "bn"], inplace=True)
        model_eager = quantize(conv_model, test_only_eval_fn, [self.img_data_2d])
        qconfig_dict = {"": default_qconfig}
        model_script = quantize_jit(
            torch.jit.script(conv_model_to_script),
            qconfig_dict,
            test_only_eval_fn,
            [self.img_data_2d],
            inplace=False,
        )
        result_eager = model_eager(self.img_data_2d[0][0])
        result_script = model_script(self.img_data_2d[0][0])
        self.assertEqual(result_eager, result_script)

    @override_qengines
    def test_nested(self):
        # Eager mode
        eager_model = AnnotatedNestedModel(torch.backends.quantized.engine).eval()

        # Graph mode
        script_model = NestedModel().eval()
        # Copy weights for eager_model
        script_model.sub1.fc.weight = torch.nn.Parameter(
            eager_model.sub1.fc.weight.detach()
        )
        script_model.sub1.fc.bias = torch.nn.Parameter(
            eager_model.sub1.fc.bias.detach()
        )
        script_model.sub2.fc1.weight = torch.nn.Parameter(
            eager_model.sub2.fc1.module.weight.detach()
        )
        script_model.sub2.fc1.bias = torch.nn.Parameter(
            eager_model.sub2.fc1.module.bias.detach()
        )
        script_model.sub2.fc2.weight = torch.nn.Parameter(
            eager_model.sub2.fc2.weight.detach()
        )
        script_model.sub2.fc2.bias = torch.nn.Parameter(
            eager_model.sub2.fc2.bias.detach()
        )
        script_model.fc3.weight = torch.nn.Parameter(
            eager_model.fc3.module.weight.detach()
        )
        script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())

        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
        qconfig_dict = {
            "sub2.fc1": default_per_channel_qconfig
            if qengine_is_fbgemm()
            else default_qconfig,
            "fc3": default_qconfig,
        }
        model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
        model_script = torch.jit.script(script_model)
        result_eager = model_eager(self.calib_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_jit(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.calib_data],
                inplace=False,
            )
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    @override_qengines
    def test_skip_quant(self):
        """Test None qconfig"""
        # Eager mode
        eager_model = AnnotatedSkipQuantModel(torch.backends.quantized.engine).eval()

        # Graph mode
        script_model = SkipQuantModel().eval()
        # Copy weights for eager_model
        script_model.sub.fc1.weight = torch.nn.Parameter(
            eager_model.sub.module.fc1.weight.detach()
        )
        script_model.sub.fc1.bias = torch.nn.Parameter(
            eager_model.sub.module.fc1.bias.detach()
        )
        script_model.sub.fc2.weight = torch.nn.Parameter(
            eager_model.sub.module.fc2.weight.detach()
        )
        script_model.sub.fc2.bias = torch.nn.Parameter(
            eager_model.sub.module.fc2.bias.detach()
        )
        script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach())
        script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach())

        eager_model.fuse_modules()

        model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
        qconfig_dict = {
            "": get_default_qconfig(torch.backends.quantized.engine),
            "fc": None,
        }
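        # A None qconfig for "fc" means that submodule is skipped during
        # quantization, mirroring the eager mode AnnotatedSkipQuantModel.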
        model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
        model_script = torch.jit.script(script_model)
        result_eager = model_eager(self.calib_data[0][0])
        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_jit(
                model_under_test,
                qconfig_dict,
                test_only_eval_fn,
                [self.calib_data],
                inplace=False,
            )
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

    @override_qengines
    def test_single_linear_dynamic(self):
        r"""Compare the result of dynamic quantization of a single linear layer
        in eager mode and graph mode.
        """
        if qengine_is_qnnpack():
            # eager mode
            annotated_linear_model = AnnotatedSingleLayerLinearModel("qnnpack").eval()
            linear_model = SingleLayerLinearModel().eval()
            # copy the weight from eager mode so that we can
            # compare the result of the two quantized models later
            linear_model.fc1.weight = torch.nn.Parameter(
                annotated_linear_model.fc1.module.weight.detach()
            )
            linear_model.fc1.bias = torch.nn.Parameter(
                annotated_linear_model.fc1.module.bias.detach()
            )
            qconfig_dict = {"": default_dynamic_qconfig}
            model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict)

            model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
            model_script = torch.jit.script(linear_model)
            result_eager = model_eager(self.calib_data[0][0])

            for model_under_test in [model_traced, model_script]:
                model_quantized = quantize_dynamic_jit(model_under_test, qconfig_dict)
                self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

                # Check to make sure choose_qparams->quant->dequant->linear is
                # numerically equivalent to the final quantized model.
                model_fake_quantized = quantize_dynamic_jit(
                    model_under_test, qconfig_dict, debug=True
                )
                self.assertEqual(
                    model_fake_quantized(self.calib_data[0][0]), result_eager
                )

    @skipIfNoFBGEMM
    def test_linear_dynamic_fp16(self):
        linear_model = SingleLayerLinearModel().eval()
        # Create weight tensor values that are beyond fp16 max
        x = torch.ones(5, 5) * 65532
        linear_model.fc1.weight = torch.nn.Parameter(x)
        import warnings

        model_eager = quantize_dynamic(linear_model, dtype=torch.float16)
        result_eager = model_eager(self.calib_data[0][0])
        for trace in [True]:
            with warnings.catch_warnings(record=True) as w:
                quantized_model = self.checkGraphModeOp(
                    linear_model,
                    self.calib_data[0][0],
                    "quantized::linear_dynamic_fp16",
                    tracing=trace,
                    dynamic=True,
                    qconfig=float16_dynamic_qconfig,
                )
            # compare result with eager mode
            self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)