# Owner(s): ["oncall: mobile"]

import io
import itertools
import unittest

from hypothesis import assume, given, strategies as st

import torch
import torch.backends.xnnpack
import torch.testing._internal.hypothesis_utils as hu
from torch.nn import functional as F
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    slowTest,
    TEST_WITH_TSAN,
    TestCase,
)
from torch.utils.mobile_optimizer import optimize_for_mobile


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK; there does not seem to be a clear reason for the failures.",
)
class TestXNNPACKOps(TestCase):
    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        input_size=st.integers(2, 32),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
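        # Keep only parameter draws that produce a non-empty output: the
        # dilated kernel extent, dilation * (kernel - 1) + 1, must fit within
        # the padded input in each spatial dimension.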
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        ref_result = F.conv2d(
            input_data, weight, bias, strides, paddings, dilations, groups
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
            weight, bias, strides, paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(1, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        # Note: conv_transpose2d takes groups before dilation, the reverse of
        # conv2d's argument order.
        ref_result = F.conv_transpose2d(
            input_data,
            weight,
            bias,
            strides,
            paddings,
            output_paddings,
            groups,
            dilation,
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
            weight, bias, strides, paddings, output_paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3
        )


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.", 235) 236@unittest.skipIf( 237 TEST_WITH_TSAN, 238 "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", 239) 240class TestXNNPACKSerDes(TestCase): 241 @unittest.skip( 242 "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" 243 ) 244 @given( 245 batch_size=st.integers(0, 3), 246 data_shape=hu.array_shapes(1, 3, 2, 64), 247 weight_output_dim=st.integers(2, 64), 248 use_bias=st.booleans(), 249 ) 250 def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): 251 class Linear(torch.nn.Module): 252 def __init__(self, weight, bias=None): 253 super().__init__() 254 self.weight = weight 255 self.bias = bias 256 257 def forward(self, x): 258 return F.linear(x, self.weight, self.bias) 259 260 class LinearPrePacked(torch.nn.Module): 261 def __init__(self, weight, bias=None): 262 super().__init__() 263 self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack( 264 weight, bias 265 ) 266 267 def forward(self, x): 268 return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias) 269 270 data_shape = [batch_size] + list(data_shape) 271 weight = torch.rand((weight_output_dim, data_shape[-1])) 272 if use_bias: 273 bias = torch.rand(weight_output_dim) 274 else: 275 bias = None 276 scripted_linear = torch.jit.script(Linear(weight, bias)) 277 scripted_linear_clamp_prepacked = torch.jit.script( 278 LinearPrePacked(weight, bias) 279 ) 280 input_data = torch.rand(data_shape) 281 ref_result = scripted_linear(input_data) 282 output_linearprepacked = scripted_linear_clamp_prepacked(input_data) 283 torch.testing.assert_close( 284 ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 285 ) 286 287 # Serialize the modules and then deserialize 288 input_data = torch.rand(data_shape) 289 buffer = io.BytesIO() 290 torch.jit.save(scripted_linear, buffer) 291 buffer.seek(0) 292 deserialized_linear = torch.jit.load(buffer) 293 buffer = io.BytesIO() 294 torch.jit.save(scripted_linear_clamp_prepacked, buffer) 295 buffer.seek(0) 296 deserialized_linear_clamp_prepacked = torch.jit.load(buffer) 297 ref_result = deserialized_linear(input_data) 298 output_linearprepacked = deserialized_linear_clamp_prepacked(input_data) 299 torch.testing.assert_close( 300 ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 301 ) 302 303 @given( 304 batch_size=st.integers(0, 3), 305 input_channels_per_group=st.integers(1, 32), 306 height=st.integers(5, 64), 307 width=st.integers(5, 64), 308 output_channels_per_group=st.integers(1, 32), 309 groups=st.integers(1, 16), 310 kernel_h=st.integers(1, 7), 311 kernel_w=st.integers(1, 7), 312 stride_h=st.integers(1, 2), 313 stride_w=st.integers(1, 2), 314 pad_h=st.integers(0, 2), 315 pad_w=st.integers(0, 2), 316 dilation=st.integers(1, 2), 317 use_bias=st.booleans(), 318 format=st.sampled_from( 319 [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 320 ), 321 ) 322 def test_conv2d( 323 self, 324 batch_size, 325 input_channels_per_group, 326 height, 327 width, 328 output_channels_per_group, 329 groups, 330 kernel_h, 331 kernel_w, 332 stride_h, 333 stride_w, 334 pad_h, 335 pad_w, 336 dilation, 337 use_bias, 338 format, 339 ): 340 class Conv2D(torch.nn.Module): 341 def __init__(self, weight, bias, strides, paddings, dilations, groups): 342 super().__init__() 343 self.weight = weight 344 self.bias = bias 345 self.strides = strides 346 self.paddings = paddings 347 self.dilations = dilations 348 self.groups = groups 349 350 def forward(self, x): 351 return 
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
                    weight, bias, strides, paddings, dilations, groups
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2D(weight, bias, strides, paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups)
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        class Conv2DT(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        class Conv2DTPrePacked(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.packed_weight_bias = (
                    torch.ops.prepacked.conv2d_transpose_clamp_prepack(
                        weight,
                        bias,
                        strides,
                        paddings,
                        output_paddings,
                        dilations,
                        groups,
                    )
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(
                    x, self.packed_weight_bias
                )

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DTPrePacked(
                weight, bias, strides, paddings, output_paddings, dilations, groups
            )
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        linear_weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_combined_model(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        linear_weight_output_dim,
        use_bias,
        format,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return F.relu(o)

        class MPrePacked(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = (
                    torch.ops.prepacked.conv2d_clamp_prepack(
                        conv_weight, conv_bias, strides, paddings, dilations, groups
                    )
                )
                self.linear_clamp_run_weight_bias = (
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
                )

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(
                    x, self.conv2d_clamp_run_weight_bias
                )
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(
                    o, self.linear_clamp_run_weight_bias
                )
                return F.relu(o)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = None
        if use_bias:
            conv_bias = torch.rand(output_channels)

        # This is done just to find the output shape of the result
        # so that the shape of weight for the following linear layer
        # can be determined.
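        # After M permutes the conv output with permute([0, 2, 3, 1]), channels
        # land in the last dimension, so the linear weight's input-feature size
        # must equal the conv's output channels, i.e. result.shape[1].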
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]

        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
        linear_bias = None
        if use_bias:
            linear_bias = torch.rand(linear_weight_output_dim)

        scripted_m = torch.jit.script(
            M(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        scripted_m_prepacked = torch.jit.script(
            MPrePacked(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        ref_result = scripted_m(input_data)
        xnnpack_result = scripted_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        input_data = input_data.contiguous(memory_format=torch.channels_last)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m, buffer)
        buffer.seek(0)
        deserialized_m = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m_prepacked, buffer)
        buffer.seek(0)
        deserialized_m_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_m(input_data)
        xnnpack_result = deserialized_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK; there does not seem to be a clear reason for the failures.",
)
class TestXNNPACKRewritePass(TestCase):
    @staticmethod
    def validate_transformed_module(
        # `self` is the module under test, not a TestCase; the name keeps flake8 quiet.
        self,
        pattern_count_map,
        data_shape,
        prepack_removal=False,
        fuse_clamping_ops=False,
    ):
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            if fuse_clamping_ops or prepack_removal:
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
            if fuse_clamping_ops:
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
            if prepack_removal:
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            # Pattern-count convention: 0 => the pattern must appear, -1 => the
            # pattern must not appear, any other value => exactly that many
            # occurrences.
            for pattern, v in pattern_count_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_scripted_model.graph
                    )
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    def test_linear(self):
        data_shape = [2, 3, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearNoBias(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, None)

        # Linear with bias pattern.
        pattern_count_map = {
            "Tensor = prim::CallFunction": -1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Linear(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            LinearNoBias(), pattern_count_map, data_shape
        )

        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        output_pad_h = output_pad_w = 0
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (
            output_channels,
            input_channels_per_group,
            kernel_h,
            kernel_w,
        )
        conv_transpose_weight_shape = (
            input_channels,
            output_channels_per_group,
            kernel_h,
            kernel_w,
        )
        conv_bias_shape = output_channels

        class Conv2D(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DT(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_transpose_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2D(), pattern_count_map, data_shape
        )

        transpose_data_shape = (batch_size, input_channels, height, width)
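        # The same rewrite should map aten::conv_transpose2d onto the prepacked
        # transpose-conv prepack/run pair, checked with the pattern counts below.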
        transpose_pattern_count_map = {
            "Tensor = aten::conv_transpose2d": -1,
            "prepacked::conv2d_transpose_clamp_prepack": 1,
            "prepacked::conv2d_transpose_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2DT(), transpose_pattern_count_map, transpose_data_shape
        )

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class M(torch.nn.Module):
            def __init__(self, activation_fn=F.relu):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.conv_bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups
                self.activation_fn = activation_fn

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = self.activation_fn(o)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return self.activation_fn(o)

        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["Tensor = prim::CallFunction"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )

        # Out-of-place relu fusion test.
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace relu fusion test.
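        # FileCheck patterns match substrings, so the "aten::relu" pattern below
        # also counts the in-place aten::relu_ calls produced by F.relu_.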
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Out-of-place hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh_": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh_"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        class MFusionAntiPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.relu(o)
                o = F.hardtanh(o)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh cannot be fused.
            "aten::relu": -1,  # relu is fused.
1130 "prepacked::linear_clamp_prepack": -1, 1131 "prepacked::linear_clamp_run": 1, 1132 } 1133 TestXNNPACKRewritePass.validate_transformed_module( 1134 MFusionAntiPattern(), 1135 pattern_count_map, 1136 (16, linear_weight_shape[1]), 1137 prepack_removal=True, 1138 fuse_clamping_ops=True, 1139 ) 1140 1141 class MFusionAntiPatternParamMinMax(torch.nn.Module): 1142 def __init__(self) -> None: 1143 super().__init__() 1144 self.linear_weight = torch.nn.Parameter( 1145 torch.rand(linear_weight_shape), requires_grad=False 1146 ) 1147 self.linear_bias = torch.nn.Parameter( 1148 torch.rand(weight_output_dim), requires_grad=False 1149 ) 1150 self.strides = strides 1151 self.paddings = paddings 1152 self.dilations = dilations 1153 self.groups = groups 1154 1155 def forward(self, x): 1156 min = x[0, 0] 1157 max = min + 10 1158 o = F.linear(x, self.linear_weight, self.linear_bias) 1159 o = F.hardtanh(o, min, max) 1160 return o 1161 1162 # Unfusable hardtanh. 1163 pattern_count_map = { 1164 "aten::hardtanh": 1, # hardtanh cannot be. 1165 "prepacked::linear_clamp_prepack": -1, 1166 "prepacked::linear_clamp_run": 1, 1167 } 1168 TestXNNPACKRewritePass.validate_transformed_module( 1169 MFusionAntiPatternParamMinMax(), 1170 pattern_count_map, 1171 (16, linear_weight_shape[1]), 1172 prepack_removal=True, 1173 fuse_clamping_ops=True, 1174 ) 1175 1176 def test_decomposed_linear(self): 1177 data_shape = [2, 32] 1178 weight_output_dim = 24 1179 weight_shape = (weight_output_dim, data_shape[-1]) 1180 1181 class DecomposedLinearAddmm(torch.nn.Module): 1182 def __init__(self) -> None: 1183 super().__init__() 1184 self.weight = torch.nn.Parameter( 1185 torch.rand(weight_shape), requires_grad=False 1186 ) 1187 self.bias = torch.nn.Parameter( 1188 torch.rand(weight_output_dim), requires_grad=False 1189 ) 1190 1191 def forward(self, x): 1192 weight_t = self.weight.t() 1193 return torch.addmm(self.bias, x, weight_t) 1194 1195 class DecomposedLinearMatmulAdd(torch.nn.Module): 1196 def __init__(self) -> None: 1197 super().__init__() 1198 self.weight = torch.nn.Parameter( 1199 torch.rand(weight_shape), requires_grad=False 1200 ) 1201 self.bias = torch.nn.Parameter( 1202 torch.rand(weight_output_dim), requires_grad=False 1203 ) 1204 1205 def forward(self, x): 1206 weight_t = self.weight.t() 1207 y = torch.matmul(x, weight_t) 1208 res = y.add_(self.bias) 1209 return res 1210 1211 class DecomposedLinearMatmul(torch.nn.Module): 1212 def __init__(self) -> None: 1213 super().__init__() 1214 self.weight = torch.nn.Parameter( 1215 torch.rand(weight_shape), requires_grad=False 1216 ) 1217 self.bias = torch.nn.Parameter( 1218 torch.rand(weight_output_dim), requires_grad=False 1219 ) 1220 1221 def forward(self, x): 1222 weight_t = self.weight.t() 1223 res = torch.matmul(x, weight_t) 1224 return res 1225 1226 # Linear with bias pattern. 1227 pattern_count_map = { 1228 "Tensor = prim::CallFunction": -1, 1229 "prepacked::linear_clamp_prepack": 1, 1230 "prepacked::linear_clamp_run": 1, 1231 } 1232 TestXNNPACKRewritePass.validate_transformed_module( 1233 DecomposedLinearAddmm(), pattern_count_map, data_shape 1234 ) 1235 TestXNNPACKRewritePass.validate_transformed_module( 1236 DecomposedLinearMatmulAdd(), pattern_count_map, data_shape 1237 ) 1238 TestXNNPACKRewritePass.validate_transformed_module( 1239 DecomposedLinearMatmul(), pattern_count_map, data_shape 1240 ) 1241 1242 1243@unittest.skipUnless( 1244 torch.backends.xnnpack.enabled, 1245 " XNNPACK must be enabled for these tests." 
" Please build with USE_XNNPACK=1.", 1246) 1247@unittest.skipIf( 1248 TEST_WITH_TSAN, 1249 "TSAN is not fork-safe since we're forking in a multi-threaded environment", 1250) 1251class TestXNNPACKConv1dTransformPass(TestCase): 1252 @staticmethod 1253 def validate_transform_conv1d_to_conv2d( 1254 self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape 1255 ): 1256 input_data = torch.normal(1, 20, size=data_shape) 1257 1258 for jit_method in ["script", "trace"]: 1259 module_instance = self 1260 if jit_method == "script": 1261 scripted_model = torch.jit.script(module_instance) 1262 else: 1263 scripted_model = torch.jit.trace(module_instance, input_data) 1264 scripted_model.eval() 1265 ref_result = scripted_model(input_data) 1266 torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c) 1267 optimized_scripted_model = optimize_for_mobile(scripted_model) 1268 1269 buffer = io.BytesIO() 1270 torch.jit.save(scripted_model, buffer) 1271 buffer.seek(0) 1272 deserialized_scripted_model = torch.jit.load(buffer) 1273 1274 for pattern, v in pattern_count_transformed_map.items(): 1275 if v == 0: 1276 FileCheck().check(pattern).run(deserialized_scripted_model.graph) 1277 elif v == -1: 1278 FileCheck().check_not(pattern).run( 1279 deserialized_scripted_model.graph 1280 ) 1281 else: 1282 FileCheck().check_count(pattern, v, exactly=True).run( 1283 deserialized_scripted_model.graph 1284 ) 1285 transformed_result = deserialized_scripted_model(input_data) 1286 torch.testing.assert_close( 1287 ref_result, transformed_result, rtol=1e-2, atol=1e-3 1288 ) 1289 1290 optimized_buffer = io.BytesIO() 1291 torch.jit.save(optimized_scripted_model, optimized_buffer) 1292 optimized_buffer.seek(0) 1293 deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer) 1294 1295 for pattern, v in pattern_count_optimized_map.items(): 1296 if v == 0: 1297 FileCheck().check(pattern).run( 1298 deserialized_optimized_scripted_model.graph 1299 ) 1300 elif v == -1: 1301 FileCheck().check_not(pattern).run( 1302 deserialized_optimized_scripted_model.graph 1303 ) 1304 else: 1305 FileCheck().check_count(pattern, v, exactly=True).run( 1306 deserialized_optimized_scripted_model.graph 1307 ) 1308 xnnpack_result = deserialized_optimized_scripted_model(input_data) 1309 torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 1310 1311 @unittest.skipIf(IS_FBCODE, "T137513244") 1312 def test_conv1d_basic(self): 1313 batch_size_list = range(1, 3) 1314 input_channels_per_group_list = range(10, 12) 1315 width_list = range(10, 12) 1316 output_channels_per_group_list = range(10, 12) 1317 groups_list = range(1, 3) 1318 kernel_list = range(1, 4) 1319 stride_list = range(1, 3) 1320 padding_list = range(0, 3) 1321 dilation_list = range(1, 3) 1322 1323 for hparams in itertools.product( 1324 batch_size_list, 1325 input_channels_per_group_list, 1326 width_list, 1327 output_channels_per_group_list, 1328 groups_list, 1329 kernel_list, 1330 stride_list, 1331 padding_list, 1332 dilation_list, 1333 ): 1334 ( 1335 batch_size, 1336 input_channels_per_group, 1337 width, 1338 output_channels_per_group, 1339 groups, 1340 kernel, 1341 stride, 1342 padding, 1343 dilation, 1344 ) = hparams 1345 1346 input_channels = input_channels_per_group * groups 1347 output_channels = output_channels_per_group * groups 1348 conv_weight_shape = (output_channels, input_channels_per_group, kernel) 1349 conv_bias_shape = output_channels 1350 1351 class Conv1D(torch.nn.Module): 1352 def __init__(self) -> None: 1353 
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                def forward(self, x):
                    return F.conv1d(
                        x,
                        self.weight,
                        self.bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": 1,
            }
            pattern_count_optimized_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": -1,
                "prepacked::conv2d_clamp_prepack": -1,
                "prepacked::conv2d_clamp_run": 1,
            }

            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
                Conv1D(),
                pattern_count_transformed_map,
                pattern_count_optimized_map,
                data_shape,
            )

    # See https://github.com/pytorch/pytorch/issues/46066
    @slowTest
    def test_conv1d_with_relu_fc(self):
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)
        output_features_list = range(1, 3)

        for hparams in itertools.product(
            batch_size_list,
            input_channels_per_group_list,
            width_list,
            output_channels_per_group_list,
            groups_list,
            kernel_list,
            stride_list,
            padding_list,
            dilation_list,
            output_features_list,
        ):
            (
                batch_size,
                input_channels_per_group,
                width,
                output_channels_per_group,
                groups,
                kernel,
                stride,
                padding,
                dilation,
                output_features,
            ) = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = output_channels
            conv_output_width = (
                int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
            )
            fc_weight_shape = (output_features, output_channels * conv_output_width)
            fc_bias_shape = output_features

            class Net(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv_weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.conv_bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                    self.fc_weight = torch.nn.Parameter(
                        torch.rand(fc_weight_shape), requires_grad=False
                    )
                    self.fc_bias = torch.nn.Parameter(
                        torch.rand(fc_bias_shape), requires_grad=False
                    )

                def forward(self, x):
                    x = F.conv1d(
                        x,
                        self.conv_weight,
                        self.conv_bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )
                    x = F.relu(x)
                    x = x.view(x.size(0), -1)
                    x = F.linear(x, self.fc_weight, self.fc_bias)
                    return x

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
1483 "Tensor = aten::conv2d": 1, 1484 } 1485 pattern_count_optimized_map = { 1486 "Tensor = aten::conv1d": -1, 1487 "Tensor = aten::conv2d": -1, 1488 "prepacked::conv2d_clamp_prepack": -1, 1489 "prepacked::conv2d_clamp_run": 1, 1490 } 1491 TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d( 1492 Net(), 1493 pattern_count_transformed_map, 1494 pattern_count_optimized_map, 1495 data_shape, 1496 ) 1497 1498 1499if __name__ == "__main__": 1500 run_tests() 1501