1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerimport itertools 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerfrom hypothesis import assume, given, strategies as st 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerimport torch.backends.xnnpack 11*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.hypothesis_utils as hu 12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 15*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 16*da0073e9SAndroid Build Coastguard Worker run_tests, 17*da0073e9SAndroid Build Coastguard Worker slowTest, 18*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 19*da0073e9SAndroid Build Coastguard Worker TestCase, 20*da0073e9SAndroid Build Coastguard Worker) 21*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.mobile_optimizer import optimize_for_mobile 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless( 25*da0073e9SAndroid Build Coastguard Worker torch.backends.xnnpack.enabled, 26*da0073e9SAndroid Build Coastguard Worker " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", 27*da0073e9SAndroid Build Coastguard Worker) 28*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 29*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 30*da0073e9SAndroid Build Coastguard Worker "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", 31*da0073e9SAndroid Build Coastguard Worker) 32*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKOps(TestCase): 33*da0073e9SAndroid Build Coastguard Worker @unittest.skip( 34*da0073e9SAndroid Build Coastguard Worker "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" 35*da0073e9SAndroid Build Coastguard Worker ) 36*da0073e9SAndroid Build Coastguard Worker @given( 37*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 38*da0073e9SAndroid Build Coastguard Worker data_shape=hu.array_shapes(1, 3, 2, 64), 39*da0073e9SAndroid Build Coastguard Worker weight_output_dim=st.integers(2, 64), 40*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 41*da0073e9SAndroid Build Coastguard Worker ) 42*da0073e9SAndroid Build Coastguard Worker def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): 43*da0073e9SAndroid Build Coastguard Worker data_shape = [batch_size] + list(data_shape) 44*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand(data_shape) 45*da0073e9SAndroid Build Coastguard Worker weight = torch.rand((weight_output_dim, data_shape[-1])) 46*da0073e9SAndroid Build Coastguard Worker if use_bias: 47*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(weight_output_dim) 48*da0073e9SAndroid Build Coastguard Worker else: 49*da0073e9SAndroid Build Coastguard Worker bias = None 50*da0073e9SAndroid Build Coastguard Worker ref_result = F.linear(input_data, weight, bias) 51*da0073e9SAndroid Build Coastguard Worker packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) 52*da0073e9SAndroid Build Coastguard Worker output_linearprepacked = torch.ops.prepacked.linear_clamp_run( 53*da0073e9SAndroid Build Coastguard Worker input_data, packed_weight_bias 54*da0073e9SAndroid Build Coastguard Worker ) 55*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 56*da0073e9SAndroid Build Coastguard Worker ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 57*da0073e9SAndroid Build Coastguard Worker ) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker @given( 60*da0073e9SAndroid Build Coastguard Worker input_size=st.integers(2, 32), 61*da0073e9SAndroid Build Coastguard Worker weight_output_dim=st.integers(2, 64), 62*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 63*da0073e9SAndroid Build Coastguard Worker ) 64*da0073e9SAndroid Build Coastguard Worker def test_linear_1d_input(self, input_size, weight_output_dim, use_bias): 65*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand(input_size) 66*da0073e9SAndroid Build Coastguard Worker weight = torch.rand((weight_output_dim, input_data.shape[-1])) 67*da0073e9SAndroid Build Coastguard Worker if use_bias: 68*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(weight_output_dim) 69*da0073e9SAndroid Build Coastguard Worker else: 70*da0073e9SAndroid Build Coastguard Worker bias = None 71*da0073e9SAndroid Build Coastguard Worker ref_result = F.linear(input_data, weight, bias) 72*da0073e9SAndroid Build Coastguard Worker packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) 73*da0073e9SAndroid Build Coastguard Worker output_linearprepacked = torch.ops.prepacked.linear_clamp_run( 74*da0073e9SAndroid Build Coastguard Worker input_data, packed_weight_bias 75*da0073e9SAndroid Build Coastguard Worker ) 76*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 77*da0073e9SAndroid Build Coastguard Worker ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 78*da0073e9SAndroid Build Coastguard Worker ) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker @given( 81*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 82*da0073e9SAndroid Build Coastguard Worker input_channels_per_group=st.integers(1, 32), 83*da0073e9SAndroid Build Coastguard Worker height=st.integers(5, 64), 84*da0073e9SAndroid Build Coastguard Worker width=st.integers(5, 64), 85*da0073e9SAndroid Build Coastguard Worker output_channels_per_group=st.integers(1, 32), 86*da0073e9SAndroid Build Coastguard Worker groups=st.integers(1, 16), 87*da0073e9SAndroid Build Coastguard Worker kernel_h=st.integers(1, 7), 88*da0073e9SAndroid Build Coastguard Worker kernel_w=st.integers(1, 7), 89*da0073e9SAndroid Build Coastguard Worker stride_h=st.integers(1, 2), 90*da0073e9SAndroid Build Coastguard Worker stride_w=st.integers(1, 2), 91*da0073e9SAndroid Build Coastguard Worker pad_h=st.integers(0, 2), 92*da0073e9SAndroid Build Coastguard Worker pad_w=st.integers(0, 2), 93*da0073e9SAndroid Build Coastguard Worker dilation=st.integers(1, 2), 94*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 95*da0073e9SAndroid Build Coastguard Worker format=st.sampled_from( 96*da0073e9SAndroid Build Coastguard Worker [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 97*da0073e9SAndroid Build Coastguard Worker ), 98*da0073e9SAndroid Build Coastguard Worker ) 99*da0073e9SAndroid Build Coastguard Worker def test_conv2d( 100*da0073e9SAndroid Build Coastguard Worker self, 101*da0073e9SAndroid Build Coastguard Worker batch_size, 102*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 103*da0073e9SAndroid Build Coastguard Worker height, 104*da0073e9SAndroid Build Coastguard Worker width, 105*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 106*da0073e9SAndroid Build Coastguard Worker groups, 107*da0073e9SAndroid Build Coastguard Worker kernel_h, 108*da0073e9SAndroid Build Coastguard Worker kernel_w, 109*da0073e9SAndroid Build Coastguard Worker stride_h, 110*da0073e9SAndroid Build Coastguard Worker stride_w, 111*da0073e9SAndroid Build Coastguard Worker pad_h, 112*da0073e9SAndroid Build Coastguard Worker pad_w, 113*da0073e9SAndroid Build Coastguard Worker dilation, 114*da0073e9SAndroid Build Coastguard Worker use_bias, 115*da0073e9SAndroid Build Coastguard Worker format, 116*da0073e9SAndroid Build Coastguard Worker ): 117*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 118*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 119*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 120*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 121*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 122*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 123*da0073e9SAndroid Build Coastguard Worker assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) 124*da0073e9SAndroid Build Coastguard Worker assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 127*da0073e9SAndroid Build Coastguard Worker if format is not None: 128*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 129*da0073e9SAndroid Build Coastguard Worker weight = torch.rand( 130*da0073e9SAndroid Build Coastguard Worker (output_channels, input_channels_per_group, kernel_h, kernel_w) 131*da0073e9SAndroid Build Coastguard Worker ) 132*da0073e9SAndroid Build Coastguard Worker bias = None 133*da0073e9SAndroid Build Coastguard Worker if use_bias: 134*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(output_channels) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker ref_result = F.conv2d( 137*da0073e9SAndroid Build Coastguard Worker input_data, weight, bias, strides, paddings, dilations, groups 138*da0073e9SAndroid Build Coastguard Worker ) 139*da0073e9SAndroid Build Coastguard Worker packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack( 140*da0073e9SAndroid Build Coastguard Worker weight, bias, strides, paddings, dilations, groups 141*da0073e9SAndroid Build Coastguard Worker ) 142*da0073e9SAndroid Build Coastguard Worker xnnpack_result = torch.ops.prepacked.conv2d_clamp_run( 143*da0073e9SAndroid Build Coastguard Worker input_data, packed_weight_bias 144*da0073e9SAndroid Build Coastguard Worker ) 145*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker @given( 148*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(1, 3), 149*da0073e9SAndroid Build Coastguard Worker input_channels_per_group=st.integers(1, 32), 150*da0073e9SAndroid Build Coastguard Worker height=st.integers(5, 64), 151*da0073e9SAndroid Build Coastguard Worker width=st.integers(5, 64), 152*da0073e9SAndroid Build Coastguard Worker output_channels_per_group=st.integers(1, 32), 153*da0073e9SAndroid Build Coastguard Worker groups=st.integers(1, 16), 154*da0073e9SAndroid Build Coastguard Worker kernel_h=st.integers(1, 7), 155*da0073e9SAndroid Build Coastguard Worker kernel_w=st.integers(1, 7), 156*da0073e9SAndroid Build Coastguard Worker stride_h=st.integers(1, 2), 157*da0073e9SAndroid Build Coastguard Worker stride_w=st.integers(1, 2), 158*da0073e9SAndroid Build Coastguard Worker pad_h=st.integers(0, 2), 159*da0073e9SAndroid Build Coastguard Worker pad_w=st.integers(0, 2), 160*da0073e9SAndroid Build Coastguard Worker output_pad_h=st.integers(0, 2), 161*da0073e9SAndroid Build Coastguard Worker output_pad_w=st.integers(0, 2), 162*da0073e9SAndroid Build Coastguard Worker dilation=st.integers(1, 2), 163*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 164*da0073e9SAndroid Build Coastguard Worker format=st.sampled_from( 165*da0073e9SAndroid Build Coastguard Worker [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 166*da0073e9SAndroid Build Coastguard Worker ), 167*da0073e9SAndroid Build Coastguard Worker ) 168*da0073e9SAndroid Build Coastguard Worker def test_conv2d_transpose( 169*da0073e9SAndroid Build Coastguard Worker self, 170*da0073e9SAndroid Build Coastguard Worker batch_size, 171*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 172*da0073e9SAndroid Build Coastguard Worker height, 173*da0073e9SAndroid Build Coastguard Worker width, 174*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 175*da0073e9SAndroid Build Coastguard Worker groups, 176*da0073e9SAndroid Build Coastguard Worker kernel_h, 177*da0073e9SAndroid Build Coastguard Worker kernel_w, 178*da0073e9SAndroid Build Coastguard Worker stride_h, 179*da0073e9SAndroid Build Coastguard Worker stride_w, 180*da0073e9SAndroid Build Coastguard Worker pad_h, 181*da0073e9SAndroid Build Coastguard Worker pad_w, 182*da0073e9SAndroid Build Coastguard Worker output_pad_h, 183*da0073e9SAndroid Build Coastguard Worker output_pad_w, 184*da0073e9SAndroid Build Coastguard Worker dilation, 185*da0073e9SAndroid Build Coastguard Worker use_bias, 186*da0073e9SAndroid Build Coastguard Worker format, 187*da0073e9SAndroid Build Coastguard Worker ): 188*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 189*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 190*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 191*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 192*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 193*da0073e9SAndroid Build Coastguard Worker output_paddings = (output_pad_h, output_pad_w) 194*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 195*da0073e9SAndroid Build Coastguard Worker assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) 196*da0073e9SAndroid Build Coastguard Worker assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) 197*da0073e9SAndroid Build Coastguard Worker assume((output_pad_h < stride_h) and (output_pad_h < dilation)) 198*da0073e9SAndroid Build Coastguard Worker assume((output_pad_w < stride_w) and (output_pad_w < dilation)) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 201*da0073e9SAndroid Build Coastguard Worker if format is not None: 202*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 203*da0073e9SAndroid Build Coastguard Worker weight = torch.rand( 204*da0073e9SAndroid Build Coastguard Worker (input_channels, output_channels_per_group, kernel_h, kernel_w) 205*da0073e9SAndroid Build Coastguard Worker ) 206*da0073e9SAndroid Build Coastguard Worker bias = None 207*da0073e9SAndroid Build Coastguard Worker if use_bias: 208*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(output_channels) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker # Note that groups/dilation is in reverse order from conv2d 211*da0073e9SAndroid Build Coastguard Worker ref_result = F.conv_transpose2d( 212*da0073e9SAndroid Build Coastguard Worker input_data, 213*da0073e9SAndroid Build Coastguard Worker weight, 214*da0073e9SAndroid Build Coastguard Worker bias, 215*da0073e9SAndroid Build Coastguard Worker strides, 216*da0073e9SAndroid Build Coastguard Worker paddings, 217*da0073e9SAndroid Build Coastguard Worker output_paddings, 218*da0073e9SAndroid Build Coastguard Worker groups, 219*da0073e9SAndroid Build Coastguard Worker dilation, 220*da0073e9SAndroid Build Coastguard Worker ) 221*da0073e9SAndroid Build Coastguard Worker packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack( 222*da0073e9SAndroid Build Coastguard Worker weight, bias, strides, paddings, output_paddings, dilations, groups 223*da0073e9SAndroid Build Coastguard Worker ) 224*da0073e9SAndroid Build Coastguard Worker xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run( 225*da0073e9SAndroid Build Coastguard Worker input_data, packed_weight_bias 226*da0073e9SAndroid Build Coastguard Worker ) 227*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 228*da0073e9SAndroid Build Coastguard Worker ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3 229*da0073e9SAndroid Build Coastguard Worker ) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless( 233*da0073e9SAndroid Build Coastguard Worker torch.backends.xnnpack.enabled, 234*da0073e9SAndroid Build Coastguard Worker " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", 235*da0073e9SAndroid Build Coastguard Worker) 236*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 237*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 238*da0073e9SAndroid Build Coastguard Worker "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", 239*da0073e9SAndroid Build Coastguard Worker) 240*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKSerDes(TestCase): 241*da0073e9SAndroid Build Coastguard Worker @unittest.skip( 242*da0073e9SAndroid Build Coastguard Worker "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" 243*da0073e9SAndroid Build Coastguard Worker ) 244*da0073e9SAndroid Build Coastguard Worker @given( 245*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 246*da0073e9SAndroid Build Coastguard Worker data_shape=hu.array_shapes(1, 3, 2, 64), 247*da0073e9SAndroid Build Coastguard Worker weight_output_dim=st.integers(2, 64), 248*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 249*da0073e9SAndroid Build Coastguard Worker ) 250*da0073e9SAndroid Build Coastguard Worker def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): 251*da0073e9SAndroid Build Coastguard Worker class Linear(torch.nn.Module): 252*da0073e9SAndroid Build Coastguard Worker def __init__(self, weight, bias=None): 253*da0073e9SAndroid Build Coastguard Worker super().__init__() 254*da0073e9SAndroid Build Coastguard Worker self.weight = weight 255*da0073e9SAndroid Build Coastguard Worker self.bias = bias 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 258*da0073e9SAndroid Build Coastguard Worker return F.linear(x, self.weight, self.bias) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker class LinearPrePacked(torch.nn.Module): 261*da0073e9SAndroid Build Coastguard Worker def __init__(self, weight, bias=None): 262*da0073e9SAndroid Build Coastguard Worker super().__init__() 263*da0073e9SAndroid Build Coastguard Worker self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack( 264*da0073e9SAndroid Build Coastguard Worker weight, bias 265*da0073e9SAndroid Build Coastguard Worker ) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 268*da0073e9SAndroid Build Coastguard Worker return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker data_shape = [batch_size] + list(data_shape) 271*da0073e9SAndroid Build Coastguard Worker weight = torch.rand((weight_output_dim, data_shape[-1])) 272*da0073e9SAndroid Build Coastguard Worker if use_bias: 273*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(weight_output_dim) 274*da0073e9SAndroid Build Coastguard Worker else: 275*da0073e9SAndroid Build Coastguard Worker bias = None 276*da0073e9SAndroid Build Coastguard Worker scripted_linear = torch.jit.script(Linear(weight, bias)) 277*da0073e9SAndroid Build Coastguard Worker scripted_linear_clamp_prepacked = torch.jit.script( 278*da0073e9SAndroid Build Coastguard Worker LinearPrePacked(weight, bias) 279*da0073e9SAndroid Build Coastguard Worker ) 280*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand(data_shape) 281*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_linear(input_data) 282*da0073e9SAndroid Build Coastguard Worker output_linearprepacked = scripted_linear_clamp_prepacked(input_data) 283*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 284*da0073e9SAndroid Build Coastguard Worker ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 285*da0073e9SAndroid Build Coastguard Worker ) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker # Serialize the modules and then deserialize 288*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand(data_shape) 289*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 290*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_linear, buffer) 291*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 292*da0073e9SAndroid Build Coastguard Worker deserialized_linear = torch.jit.load(buffer) 293*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 294*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_linear_clamp_prepacked, buffer) 295*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 296*da0073e9SAndroid Build Coastguard Worker deserialized_linear_clamp_prepacked = torch.jit.load(buffer) 297*da0073e9SAndroid Build Coastguard Worker ref_result = deserialized_linear(input_data) 298*da0073e9SAndroid Build Coastguard Worker output_linearprepacked = deserialized_linear_clamp_prepacked(input_data) 299*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 300*da0073e9SAndroid Build Coastguard Worker ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 301*da0073e9SAndroid Build Coastguard Worker ) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker @given( 304*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 305*da0073e9SAndroid Build Coastguard Worker input_channels_per_group=st.integers(1, 32), 306*da0073e9SAndroid Build Coastguard Worker height=st.integers(5, 64), 307*da0073e9SAndroid Build Coastguard Worker width=st.integers(5, 64), 308*da0073e9SAndroid Build Coastguard Worker output_channels_per_group=st.integers(1, 32), 309*da0073e9SAndroid Build Coastguard Worker groups=st.integers(1, 16), 310*da0073e9SAndroid Build Coastguard Worker kernel_h=st.integers(1, 7), 311*da0073e9SAndroid Build Coastguard Worker kernel_w=st.integers(1, 7), 312*da0073e9SAndroid Build Coastguard Worker stride_h=st.integers(1, 2), 313*da0073e9SAndroid Build Coastguard Worker stride_w=st.integers(1, 2), 314*da0073e9SAndroid Build Coastguard Worker pad_h=st.integers(0, 2), 315*da0073e9SAndroid Build Coastguard Worker pad_w=st.integers(0, 2), 316*da0073e9SAndroid Build Coastguard Worker dilation=st.integers(1, 2), 317*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 318*da0073e9SAndroid Build Coastguard Worker format=st.sampled_from( 319*da0073e9SAndroid Build Coastguard Worker [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 320*da0073e9SAndroid Build Coastguard Worker ), 321*da0073e9SAndroid Build Coastguard Worker ) 322*da0073e9SAndroid Build Coastguard Worker def test_conv2d( 323*da0073e9SAndroid Build Coastguard Worker self, 324*da0073e9SAndroid Build Coastguard Worker batch_size, 325*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 326*da0073e9SAndroid Build Coastguard Worker height, 327*da0073e9SAndroid Build Coastguard Worker width, 328*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 329*da0073e9SAndroid Build Coastguard Worker groups, 330*da0073e9SAndroid Build Coastguard Worker kernel_h, 331*da0073e9SAndroid Build Coastguard Worker kernel_w, 332*da0073e9SAndroid Build Coastguard Worker stride_h, 333*da0073e9SAndroid Build Coastguard Worker stride_w, 334*da0073e9SAndroid Build Coastguard Worker pad_h, 335*da0073e9SAndroid Build Coastguard Worker pad_w, 336*da0073e9SAndroid Build Coastguard Worker dilation, 337*da0073e9SAndroid Build Coastguard Worker use_bias, 338*da0073e9SAndroid Build Coastguard Worker format, 339*da0073e9SAndroid Build Coastguard Worker ): 340*da0073e9SAndroid Build Coastguard Worker class Conv2D(torch.nn.Module): 341*da0073e9SAndroid Build Coastguard Worker def __init__(self, weight, bias, strides, paddings, dilations, groups): 342*da0073e9SAndroid Build Coastguard Worker super().__init__() 343*da0073e9SAndroid Build Coastguard Worker self.weight = weight 344*da0073e9SAndroid Build Coastguard Worker self.bias = bias 345*da0073e9SAndroid Build Coastguard Worker self.strides = strides 346*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 347*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 348*da0073e9SAndroid Build Coastguard Worker self.groups = groups 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 351*da0073e9SAndroid Build Coastguard Worker return F.conv2d( 352*da0073e9SAndroid Build Coastguard Worker x, 353*da0073e9SAndroid Build Coastguard Worker self.weight, 354*da0073e9SAndroid Build Coastguard Worker self.bias, 355*da0073e9SAndroid Build Coastguard Worker self.strides, 356*da0073e9SAndroid Build Coastguard Worker self.paddings, 357*da0073e9SAndroid Build Coastguard Worker self.dilations, 358*da0073e9SAndroid Build Coastguard Worker self.groups, 359*da0073e9SAndroid Build Coastguard Worker ) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker class Conv2DPrePacked(torch.nn.Module): 362*da0073e9SAndroid Build Coastguard Worker def __init__(self, weight, bias, strides, paddings, dilations, groups): 363*da0073e9SAndroid Build Coastguard Worker super().__init__() 364*da0073e9SAndroid Build Coastguard Worker self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack( 365*da0073e9SAndroid Build Coastguard Worker weight, bias, strides, paddings, dilations, groups 366*da0073e9SAndroid Build Coastguard Worker ) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 369*da0073e9SAndroid Build Coastguard Worker return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 372*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 373*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 374*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 375*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 376*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 377*da0073e9SAndroid Build Coastguard Worker assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) 378*da0073e9SAndroid Build Coastguard Worker assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 381*da0073e9SAndroid Build Coastguard Worker if format is not None: 382*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 383*da0073e9SAndroid Build Coastguard Worker weight = torch.rand( 384*da0073e9SAndroid Build Coastguard Worker (output_channels, input_channels_per_group, kernel_h, kernel_w) 385*da0073e9SAndroid Build Coastguard Worker ) 386*da0073e9SAndroid Build Coastguard Worker bias = None 387*da0073e9SAndroid Build Coastguard Worker if use_bias: 388*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(output_channels) 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker scripted_conv2d = torch.jit.script( 391*da0073e9SAndroid Build Coastguard Worker Conv2D(weight, bias, strides, paddings, dilations, groups) 392*da0073e9SAndroid Build Coastguard Worker ) 393*da0073e9SAndroid Build Coastguard Worker scripted_conv2d_clamp_prepacked = torch.jit.script( 394*da0073e9SAndroid Build Coastguard Worker Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups) 395*da0073e9SAndroid Build Coastguard Worker ) 396*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_conv2d(input_data) 397*da0073e9SAndroid Build Coastguard Worker xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) 398*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker # Serialize the modules and then deserialize 401*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 402*da0073e9SAndroid Build Coastguard Worker if format is not None: 403*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 404*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 405*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_conv2d, buffer) 406*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 407*da0073e9SAndroid Build Coastguard Worker deserialized_conv2d = torch.jit.load(buffer) 408*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 409*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_conv2d_clamp_prepacked, buffer) 410*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 411*da0073e9SAndroid Build Coastguard Worker deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) 412*da0073e9SAndroid Build Coastguard Worker ref_result = deserialized_conv2d(input_data) 413*da0073e9SAndroid Build Coastguard Worker xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) 414*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker @given( 417*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 418*da0073e9SAndroid Build Coastguard Worker input_channels_per_group=st.integers(1, 32), 419*da0073e9SAndroid Build Coastguard Worker height=st.integers(5, 64), 420*da0073e9SAndroid Build Coastguard Worker width=st.integers(5, 64), 421*da0073e9SAndroid Build Coastguard Worker output_channels_per_group=st.integers(1, 32), 422*da0073e9SAndroid Build Coastguard Worker groups=st.integers(1, 16), 423*da0073e9SAndroid Build Coastguard Worker kernel_h=st.integers(1, 7), 424*da0073e9SAndroid Build Coastguard Worker kernel_w=st.integers(1, 7), 425*da0073e9SAndroid Build Coastguard Worker stride_h=st.integers(1, 2), 426*da0073e9SAndroid Build Coastguard Worker stride_w=st.integers(1, 2), 427*da0073e9SAndroid Build Coastguard Worker pad_h=st.integers(0, 2), 428*da0073e9SAndroid Build Coastguard Worker pad_w=st.integers(0, 2), 429*da0073e9SAndroid Build Coastguard Worker output_pad_h=st.integers(0, 2), 430*da0073e9SAndroid Build Coastguard Worker output_pad_w=st.integers(0, 2), 431*da0073e9SAndroid Build Coastguard Worker dilation=st.integers(1, 2), 432*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 433*da0073e9SAndroid Build Coastguard Worker format=st.sampled_from( 434*da0073e9SAndroid Build Coastguard Worker [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 435*da0073e9SAndroid Build Coastguard Worker ), 436*da0073e9SAndroid Build Coastguard Worker ) 437*da0073e9SAndroid Build Coastguard Worker def test_conv2d_transpose( 438*da0073e9SAndroid Build Coastguard Worker self, 439*da0073e9SAndroid Build Coastguard Worker batch_size, 440*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 441*da0073e9SAndroid Build Coastguard Worker height, 442*da0073e9SAndroid Build Coastguard Worker width, 443*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 444*da0073e9SAndroid Build Coastguard Worker groups, 445*da0073e9SAndroid Build Coastguard Worker kernel_h, 446*da0073e9SAndroid Build Coastguard Worker kernel_w, 447*da0073e9SAndroid Build Coastguard Worker stride_h, 448*da0073e9SAndroid Build Coastguard Worker stride_w, 449*da0073e9SAndroid Build Coastguard Worker pad_h, 450*da0073e9SAndroid Build Coastguard Worker pad_w, 451*da0073e9SAndroid Build Coastguard Worker output_pad_h, 452*da0073e9SAndroid Build Coastguard Worker output_pad_w, 453*da0073e9SAndroid Build Coastguard Worker dilation, 454*da0073e9SAndroid Build Coastguard Worker use_bias, 455*da0073e9SAndroid Build Coastguard Worker format, 456*da0073e9SAndroid Build Coastguard Worker ): 457*da0073e9SAndroid Build Coastguard Worker class Conv2DT(torch.nn.Module): 458*da0073e9SAndroid Build Coastguard Worker def __init__( 459*da0073e9SAndroid Build Coastguard Worker self, 460*da0073e9SAndroid Build Coastguard Worker weight, 461*da0073e9SAndroid Build Coastguard Worker bias, 462*da0073e9SAndroid Build Coastguard Worker strides, 463*da0073e9SAndroid Build Coastguard Worker paddings, 464*da0073e9SAndroid Build Coastguard Worker output_paddings, 465*da0073e9SAndroid Build Coastguard Worker dilations, 466*da0073e9SAndroid Build Coastguard Worker groups, 467*da0073e9SAndroid Build Coastguard Worker ): 468*da0073e9SAndroid Build Coastguard Worker super().__init__() 469*da0073e9SAndroid Build Coastguard Worker self.weight = weight 470*da0073e9SAndroid Build Coastguard Worker self.bias = bias 471*da0073e9SAndroid Build Coastguard Worker self.strides = strides 472*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 473*da0073e9SAndroid Build Coastguard Worker self.output_paddings = output_paddings 474*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 475*da0073e9SAndroid Build Coastguard Worker self.groups = groups 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 478*da0073e9SAndroid Build Coastguard Worker return F.conv_transpose2d( 479*da0073e9SAndroid Build Coastguard Worker x, 480*da0073e9SAndroid Build Coastguard Worker self.weight, 481*da0073e9SAndroid Build Coastguard Worker self.bias, 482*da0073e9SAndroid Build Coastguard Worker self.strides, 483*da0073e9SAndroid Build Coastguard Worker self.paddings, 484*da0073e9SAndroid Build Coastguard Worker self.output_paddings, 485*da0073e9SAndroid Build Coastguard Worker self.groups, 486*da0073e9SAndroid Build Coastguard Worker self.dilations, 487*da0073e9SAndroid Build Coastguard Worker ) 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker class Conv2DTPrePacked(torch.nn.Module): 490*da0073e9SAndroid Build Coastguard Worker def __init__( 491*da0073e9SAndroid Build Coastguard Worker self, 492*da0073e9SAndroid Build Coastguard Worker weight, 493*da0073e9SAndroid Build Coastguard Worker bias, 494*da0073e9SAndroid Build Coastguard Worker strides, 495*da0073e9SAndroid Build Coastguard Worker paddings, 496*da0073e9SAndroid Build Coastguard Worker output_paddings, 497*da0073e9SAndroid Build Coastguard Worker dilations, 498*da0073e9SAndroid Build Coastguard Worker groups, 499*da0073e9SAndroid Build Coastguard Worker ): 500*da0073e9SAndroid Build Coastguard Worker super().__init__() 501*da0073e9SAndroid Build Coastguard Worker self.packed_weight_bias = ( 502*da0073e9SAndroid Build Coastguard Worker torch.ops.prepacked.conv2d_transpose_clamp_prepack( 503*da0073e9SAndroid Build Coastguard Worker weight, 504*da0073e9SAndroid Build Coastguard Worker bias, 505*da0073e9SAndroid Build Coastguard Worker strides, 506*da0073e9SAndroid Build Coastguard Worker paddings, 507*da0073e9SAndroid Build Coastguard Worker output_paddings, 508*da0073e9SAndroid Build Coastguard Worker dilations, 509*da0073e9SAndroid Build Coastguard Worker groups, 510*da0073e9SAndroid Build Coastguard Worker ) 511*da0073e9SAndroid Build Coastguard Worker ) 512*da0073e9SAndroid Build Coastguard Worker 513*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 514*da0073e9SAndroid Build Coastguard Worker return torch.ops.prepacked.conv2d_transpose_clamp_run( 515*da0073e9SAndroid Build Coastguard Worker x, self.packed_weight_bias 516*da0073e9SAndroid Build Coastguard Worker ) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 519*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 520*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 521*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 522*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 523*da0073e9SAndroid Build Coastguard Worker output_paddings = (output_pad_h, output_pad_w) 524*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 525*da0073e9SAndroid Build Coastguard Worker assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) 526*da0073e9SAndroid Build Coastguard Worker assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) 527*da0073e9SAndroid Build Coastguard Worker assume((output_pad_h < stride_h) and (output_pad_h < dilation)) 528*da0073e9SAndroid Build Coastguard Worker assume((output_pad_w < stride_w) and (output_pad_w < dilation)) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 531*da0073e9SAndroid Build Coastguard Worker if format is not None: 532*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 533*da0073e9SAndroid Build Coastguard Worker weight = torch.rand( 534*da0073e9SAndroid Build Coastguard Worker (input_channels, output_channels_per_group, kernel_h, kernel_w) 535*da0073e9SAndroid Build Coastguard Worker ) 536*da0073e9SAndroid Build Coastguard Worker bias = None 537*da0073e9SAndroid Build Coastguard Worker if use_bias: 538*da0073e9SAndroid Build Coastguard Worker bias = torch.rand(output_channels) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker scripted_conv2d = torch.jit.script( 541*da0073e9SAndroid Build Coastguard Worker Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups) 542*da0073e9SAndroid Build Coastguard Worker ) 543*da0073e9SAndroid Build Coastguard Worker scripted_conv2d_clamp_prepacked = torch.jit.script( 544*da0073e9SAndroid Build Coastguard Worker Conv2DTPrePacked( 545*da0073e9SAndroid Build Coastguard Worker weight, bias, strides, paddings, output_paddings, dilations, groups 546*da0073e9SAndroid Build Coastguard Worker ) 547*da0073e9SAndroid Build Coastguard Worker ) 548*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_conv2d(input_data) 549*da0073e9SAndroid Build Coastguard Worker xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) 550*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker # Serialize the modules and then deserialize 553*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 554*da0073e9SAndroid Build Coastguard Worker if format is not None: 555*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 556*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 557*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_conv2d, buffer) 558*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 559*da0073e9SAndroid Build Coastguard Worker deserialized_conv2d = torch.jit.load(buffer) 560*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 561*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_conv2d_clamp_prepacked, buffer) 562*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 563*da0073e9SAndroid Build Coastguard Worker deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) 564*da0073e9SAndroid Build Coastguard Worker ref_result = deserialized_conv2d(input_data) 565*da0073e9SAndroid Build Coastguard Worker xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) 566*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker @unittest.skip( 569*da0073e9SAndroid Build Coastguard Worker "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" 570*da0073e9SAndroid Build Coastguard Worker ) 571*da0073e9SAndroid Build Coastguard Worker @given( 572*da0073e9SAndroid Build Coastguard Worker batch_size=st.integers(0, 3), 573*da0073e9SAndroid Build Coastguard Worker input_channels_per_group=st.integers(1, 32), 574*da0073e9SAndroid Build Coastguard Worker height=st.integers(5, 64), 575*da0073e9SAndroid Build Coastguard Worker width=st.integers(5, 64), 576*da0073e9SAndroid Build Coastguard Worker output_channels_per_group=st.integers(1, 32), 577*da0073e9SAndroid Build Coastguard Worker groups=st.integers(1, 16), 578*da0073e9SAndroid Build Coastguard Worker kernel_h=st.integers(1, 7), 579*da0073e9SAndroid Build Coastguard Worker kernel_w=st.integers(1, 7), 580*da0073e9SAndroid Build Coastguard Worker stride_h=st.integers(1, 2), 581*da0073e9SAndroid Build Coastguard Worker stride_w=st.integers(1, 2), 582*da0073e9SAndroid Build Coastguard Worker pad_h=st.integers(0, 2), 583*da0073e9SAndroid Build Coastguard Worker pad_w=st.integers(0, 2), 584*da0073e9SAndroid Build Coastguard Worker dilation=st.integers(1, 2), 585*da0073e9SAndroid Build Coastguard Worker linear_weight_output_dim=st.integers(2, 64), 586*da0073e9SAndroid Build Coastguard Worker use_bias=st.booleans(), 587*da0073e9SAndroid Build Coastguard Worker format=st.sampled_from( 588*da0073e9SAndroid Build Coastguard Worker [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] 589*da0073e9SAndroid Build Coastguard Worker ), 590*da0073e9SAndroid Build Coastguard Worker ) 591*da0073e9SAndroid Build Coastguard Worker def test_combined_model( 592*da0073e9SAndroid Build Coastguard Worker self, 593*da0073e9SAndroid Build Coastguard Worker batch_size, 594*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 595*da0073e9SAndroid Build Coastguard Worker height, 596*da0073e9SAndroid Build Coastguard Worker width, 597*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 598*da0073e9SAndroid Build Coastguard Worker groups, 599*da0073e9SAndroid Build Coastguard Worker kernel_h, 600*da0073e9SAndroid Build Coastguard Worker kernel_w, 601*da0073e9SAndroid Build Coastguard Worker stride_h, 602*da0073e9SAndroid Build Coastguard Worker stride_w, 603*da0073e9SAndroid Build Coastguard Worker pad_h, 604*da0073e9SAndroid Build Coastguard Worker pad_w, 605*da0073e9SAndroid Build Coastguard Worker dilation, 606*da0073e9SAndroid Build Coastguard Worker linear_weight_output_dim, 607*da0073e9SAndroid Build Coastguard Worker use_bias, 608*da0073e9SAndroid Build Coastguard Worker format, 609*da0073e9SAndroid Build Coastguard Worker ): 610*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 611*da0073e9SAndroid Build Coastguard Worker def __init__( 612*da0073e9SAndroid Build Coastguard Worker self, 613*da0073e9SAndroid Build Coastguard Worker conv_weight, 614*da0073e9SAndroid Build Coastguard Worker conv_bias, 615*da0073e9SAndroid Build Coastguard Worker linear_weight, 616*da0073e9SAndroid Build Coastguard Worker linear_bias, 617*da0073e9SAndroid Build Coastguard Worker strides, 618*da0073e9SAndroid Build Coastguard Worker paddings, 619*da0073e9SAndroid Build Coastguard Worker dilations, 620*da0073e9SAndroid Build Coastguard Worker groups, 621*da0073e9SAndroid Build Coastguard Worker ): 622*da0073e9SAndroid Build Coastguard Worker super().__init__() 623*da0073e9SAndroid Build Coastguard Worker self.conv_weight = conv_weight 624*da0073e9SAndroid Build Coastguard Worker self.conv_bias = conv_bias 625*da0073e9SAndroid Build Coastguard Worker self.linear_weight = linear_weight 626*da0073e9SAndroid Build Coastguard Worker self.linear_bias = linear_bias 627*da0073e9SAndroid Build Coastguard Worker self.strides = strides 628*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 629*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 630*da0073e9SAndroid Build Coastguard Worker self.groups = groups 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 633*da0073e9SAndroid Build Coastguard Worker o = F.conv2d( 634*da0073e9SAndroid Build Coastguard Worker x, 635*da0073e9SAndroid Build Coastguard Worker self.conv_weight, 636*da0073e9SAndroid Build Coastguard Worker self.conv_bias, 637*da0073e9SAndroid Build Coastguard Worker self.strides, 638*da0073e9SAndroid Build Coastguard Worker self.paddings, 639*da0073e9SAndroid Build Coastguard Worker self.dilations, 640*da0073e9SAndroid Build Coastguard Worker self.groups, 641*da0073e9SAndroid Build Coastguard Worker ) 642*da0073e9SAndroid Build Coastguard Worker o = o.permute([0, 2, 3, 1]) 643*da0073e9SAndroid Build Coastguard Worker o = F.linear(o, self.linear_weight, self.linear_bias) 644*da0073e9SAndroid Build Coastguard Worker return F.relu(o) 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker class MPrePacked(torch.nn.Module): 647*da0073e9SAndroid Build Coastguard Worker def __init__( 648*da0073e9SAndroid Build Coastguard Worker self, 649*da0073e9SAndroid Build Coastguard Worker conv_weight, 650*da0073e9SAndroid Build Coastguard Worker conv_bias, 651*da0073e9SAndroid Build Coastguard Worker linear_weight, 652*da0073e9SAndroid Build Coastguard Worker linear_bias, 653*da0073e9SAndroid Build Coastguard Worker strides, 654*da0073e9SAndroid Build Coastguard Worker paddings, 655*da0073e9SAndroid Build Coastguard Worker dilations, 656*da0073e9SAndroid Build Coastguard Worker groups, 657*da0073e9SAndroid Build Coastguard Worker ): 658*da0073e9SAndroid Build Coastguard Worker super().__init__() 659*da0073e9SAndroid Build Coastguard Worker self.conv2d_clamp_run_weight_bias = ( 660*da0073e9SAndroid Build Coastguard Worker torch.ops.prepacked.conv2d_clamp_prepack( 661*da0073e9SAndroid Build Coastguard Worker conv_weight, conv_bias, strides, paddings, dilations, groups 662*da0073e9SAndroid Build Coastguard Worker ) 663*da0073e9SAndroid Build Coastguard Worker ) 664*da0073e9SAndroid Build Coastguard Worker self.linear_clamp_run_weight_bias = ( 665*da0073e9SAndroid Build Coastguard Worker torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias) 666*da0073e9SAndroid Build Coastguard Worker ) 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 669*da0073e9SAndroid Build Coastguard Worker o = torch.ops.prepacked.conv2d_clamp_run( 670*da0073e9SAndroid Build Coastguard Worker x, self.conv2d_clamp_run_weight_bias 671*da0073e9SAndroid Build Coastguard Worker ) 672*da0073e9SAndroid Build Coastguard Worker o = o.permute([0, 2, 3, 1]) 673*da0073e9SAndroid Build Coastguard Worker o = torch.ops.prepacked.linear_clamp_run( 674*da0073e9SAndroid Build Coastguard Worker o, self.linear_clamp_run_weight_bias 675*da0073e9SAndroid Build Coastguard Worker ) 676*da0073e9SAndroid Build Coastguard Worker return F.relu(o) 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 679*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 680*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 681*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 682*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 683*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 684*da0073e9SAndroid Build Coastguard Worker assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) 685*da0073e9SAndroid Build Coastguard Worker assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 688*da0073e9SAndroid Build Coastguard Worker if format is not None: 689*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=format) 690*da0073e9SAndroid Build Coastguard Worker conv_weight = torch.rand( 691*da0073e9SAndroid Build Coastguard Worker (output_channels, input_channels_per_group, kernel_h, kernel_w) 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker conv_bias = None 694*da0073e9SAndroid Build Coastguard Worker if use_bias: 695*da0073e9SAndroid Build Coastguard Worker conv_bias = torch.rand(output_channels) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker # This is done just to find the output shape of the result 698*da0073e9SAndroid Build Coastguard Worker # so that the shape of weight for the following linear layer 699*da0073e9SAndroid Build Coastguard Worker # can be determined. 700*da0073e9SAndroid Build Coastguard Worker result = F.conv2d( 701*da0073e9SAndroid Build Coastguard Worker input_data, conv_weight, conv_bias, strides, paddings, dilations, groups 702*da0073e9SAndroid Build Coastguard Worker ) 703*da0073e9SAndroid Build Coastguard Worker linear_input_shape = result.shape[1] 704*da0073e9SAndroid Build Coastguard Worker 705*da0073e9SAndroid Build Coastguard Worker linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape)) 706*da0073e9SAndroid Build Coastguard Worker linear_bias = None 707*da0073e9SAndroid Build Coastguard Worker if use_bias: 708*da0073e9SAndroid Build Coastguard Worker linear_bias = torch.rand(linear_weight_output_dim) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker scripted_m = torch.jit.script( 711*da0073e9SAndroid Build Coastguard Worker M( 712*da0073e9SAndroid Build Coastguard Worker conv_weight, 713*da0073e9SAndroid Build Coastguard Worker conv_bias, 714*da0073e9SAndroid Build Coastguard Worker linear_weight, 715*da0073e9SAndroid Build Coastguard Worker linear_bias, 716*da0073e9SAndroid Build Coastguard Worker strides, 717*da0073e9SAndroid Build Coastguard Worker paddings, 718*da0073e9SAndroid Build Coastguard Worker dilations, 719*da0073e9SAndroid Build Coastguard Worker groups, 720*da0073e9SAndroid Build Coastguard Worker ) 721*da0073e9SAndroid Build Coastguard Worker ) 722*da0073e9SAndroid Build Coastguard Worker scripted_m_prepacked = torch.jit.script( 723*da0073e9SAndroid Build Coastguard Worker MPrePacked( 724*da0073e9SAndroid Build Coastguard Worker conv_weight, 725*da0073e9SAndroid Build Coastguard Worker conv_bias, 726*da0073e9SAndroid Build Coastguard Worker linear_weight, 727*da0073e9SAndroid Build Coastguard Worker linear_bias, 728*da0073e9SAndroid Build Coastguard Worker strides, 729*da0073e9SAndroid Build Coastguard Worker paddings, 730*da0073e9SAndroid Build Coastguard Worker dilations, 731*da0073e9SAndroid Build Coastguard Worker groups, 732*da0073e9SAndroid Build Coastguard Worker ) 733*da0073e9SAndroid Build Coastguard Worker ) 734*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_m(input_data) 735*da0073e9SAndroid Build Coastguard Worker xnnpack_result = scripted_m_prepacked(input_data) 736*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker # Serialize the modules and then deserialize 739*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 740*da0073e9SAndroid Build Coastguard Worker input_data = input_data.contiguous(memory_format=torch.channels_last) 741*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 742*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_m, buffer) 743*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 744*da0073e9SAndroid Build Coastguard Worker deserialized_m = torch.jit.load(buffer) 745*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 746*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_m_prepacked, buffer) 747*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 748*da0073e9SAndroid Build Coastguard Worker deserialized_m_prepacked = torch.jit.load(buffer) 749*da0073e9SAndroid Build Coastguard Worker ref_result = deserialized_m(input_data) 750*da0073e9SAndroid Build Coastguard Worker xnnpack_result = deserialized_m_prepacked(input_data) 751*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless( 755*da0073e9SAndroid Build Coastguard Worker torch.backends.xnnpack.enabled, 756*da0073e9SAndroid Build Coastguard Worker " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", 757*da0073e9SAndroid Build Coastguard Worker) 758*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 759*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 760*da0073e9SAndroid Build Coastguard Worker "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", 761*da0073e9SAndroid Build Coastguard Worker) 762*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKRewritePass(TestCase): 763*da0073e9SAndroid Build Coastguard Worker @staticmethod 764*da0073e9SAndroid Build Coastguard Worker def validate_transformed_module( 765*da0073e9SAndroid Build Coastguard Worker # To please flake 766*da0073e9SAndroid Build Coastguard Worker self, 767*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 768*da0073e9SAndroid Build Coastguard Worker data_shape, 769*da0073e9SAndroid Build Coastguard Worker prepack_removal=False, 770*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=False, 771*da0073e9SAndroid Build Coastguard Worker ): 772*da0073e9SAndroid Build Coastguard Worker input_data = torch.normal(1, 20, size=data_shape) 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker for jit_method in ["script", "trace"]: 775*da0073e9SAndroid Build Coastguard Worker module_instance = self 776*da0073e9SAndroid Build Coastguard Worker if jit_method == "script": 777*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(module_instance) 778*da0073e9SAndroid Build Coastguard Worker else: 779*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.trace(module_instance, input_data) 780*da0073e9SAndroid Build Coastguard Worker scripted_model.eval() 781*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_model(input_data) 782*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_insert_prepacked_ops(scripted_model._c) 783*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops or prepack_removal: 784*da0073e9SAndroid Build Coastguard Worker scripted_model._c = torch._C._freeze_module(scripted_model._c) 785*da0073e9SAndroid Build Coastguard Worker if fuse_clamping_ops: 786*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c) 787*da0073e9SAndroid Build Coastguard Worker if prepack_removal: 788*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fold_prepacking_ops(scripted_model._c) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 791*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_model, buffer) 792*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 793*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model = torch.jit.load(buffer) 794*da0073e9SAndroid Build Coastguard Worker for pattern, v in pattern_count_map.items(): 795*da0073e9SAndroid Build Coastguard Worker if v == 0: 796*da0073e9SAndroid Build Coastguard Worker FileCheck().check(pattern).run(deserialized_scripted_model.graph) 797*da0073e9SAndroid Build Coastguard Worker elif v == -1: 798*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not(pattern).run( 799*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model.graph 800*da0073e9SAndroid Build Coastguard Worker ) 801*da0073e9SAndroid Build Coastguard Worker else: 802*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(pattern, v, exactly=True).run( 803*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model.graph 804*da0073e9SAndroid Build Coastguard Worker ) 805*da0073e9SAndroid Build Coastguard Worker xnnpack_result = deserialized_scripted_model(input_data) 806*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker def test_linear(self): 809*da0073e9SAndroid Build Coastguard Worker data_shape = [2, 3, 32] 810*da0073e9SAndroid Build Coastguard Worker weight_output_dim = 24 811*da0073e9SAndroid Build Coastguard Worker weight_shape = (weight_output_dim, data_shape[-1]) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker class Linear(torch.nn.Module): 814*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 815*da0073e9SAndroid Build Coastguard Worker super().__init__() 816*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 817*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_shape), requires_grad=False 818*da0073e9SAndroid Build Coastguard Worker ) 819*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 820*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 821*da0073e9SAndroid Build Coastguard Worker ) 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 824*da0073e9SAndroid Build Coastguard Worker return F.linear(x, self.weight, self.bias) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker class LinearNoBias(torch.nn.Module): 827*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 828*da0073e9SAndroid Build Coastguard Worker super().__init__() 829*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 830*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_shape), requires_grad=False 831*da0073e9SAndroid Build Coastguard Worker ) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 834*da0073e9SAndroid Build Coastguard Worker return F.linear(x, self.weight, None) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker # Linear with bias pattern. 837*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 838*da0073e9SAndroid Build Coastguard Worker "Tensor = prim::CallFunction": -1, 839*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": 1, 840*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 841*da0073e9SAndroid Build Coastguard Worker } 842*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 843*da0073e9SAndroid Build Coastguard Worker Linear(), pattern_count_map, data_shape 844*da0073e9SAndroid Build Coastguard Worker ) 845*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 846*da0073e9SAndroid Build Coastguard Worker LinearNoBias(), pattern_count_map, data_shape 847*da0073e9SAndroid Build Coastguard Worker ) 848*da0073e9SAndroid Build Coastguard Worker 849*da0073e9SAndroid Build Coastguard Worker # Conv params 850*da0073e9SAndroid Build Coastguard Worker batch_size = 2 851*da0073e9SAndroid Build Coastguard Worker input_channels_per_group = 6 852*da0073e9SAndroid Build Coastguard Worker height = 16 853*da0073e9SAndroid Build Coastguard Worker width = 16 854*da0073e9SAndroid Build Coastguard Worker output_channels_per_group = 6 855*da0073e9SAndroid Build Coastguard Worker groups = 4 856*da0073e9SAndroid Build Coastguard Worker kernel_h = kernel_w = 3 857*da0073e9SAndroid Build Coastguard Worker stride_h = stride_w = 1 858*da0073e9SAndroid Build Coastguard Worker pad_h = pad_w = 1 859*da0073e9SAndroid Build Coastguard Worker output_pad_h = output_pad_w = 0 860*da0073e9SAndroid Build Coastguard Worker dilation = 1 861*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 862*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 863*da0073e9SAndroid Build Coastguard Worker kernels = (kernel_h, kernel_w) 864*da0073e9SAndroid Build Coastguard Worker strides = (stride_h, stride_w) 865*da0073e9SAndroid Build Coastguard Worker paddings = (pad_h, pad_w) 866*da0073e9SAndroid Build Coastguard Worker output_paddings = (output_pad_h, output_pad_w) 867*da0073e9SAndroid Build Coastguard Worker dilations = (dilation, dilation) 868*da0073e9SAndroid Build Coastguard Worker conv_weight_shape = ( 869*da0073e9SAndroid Build Coastguard Worker output_channels, 870*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 871*da0073e9SAndroid Build Coastguard Worker kernel_h, 872*da0073e9SAndroid Build Coastguard Worker kernel_w, 873*da0073e9SAndroid Build Coastguard Worker ) 874*da0073e9SAndroid Build Coastguard Worker conv_transpose_weight_shape = ( 875*da0073e9SAndroid Build Coastguard Worker input_channels, 876*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 877*da0073e9SAndroid Build Coastguard Worker kernel_h, 878*da0073e9SAndroid Build Coastguard Worker kernel_w, 879*da0073e9SAndroid Build Coastguard Worker ) 880*da0073e9SAndroid Build Coastguard Worker conv_bias_shape = output_channels 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker class Conv2D(torch.nn.Module): 883*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 884*da0073e9SAndroid Build Coastguard Worker super().__init__() 885*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 886*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_weight_shape), requires_grad=False 887*da0073e9SAndroid Build Coastguard Worker ) 888*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 889*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_bias_shape), requires_grad=False 890*da0073e9SAndroid Build Coastguard Worker ) 891*da0073e9SAndroid Build Coastguard Worker self.strides = strides 892*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 893*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 894*da0073e9SAndroid Build Coastguard Worker self.groups = groups 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 897*da0073e9SAndroid Build Coastguard Worker return F.conv2d( 898*da0073e9SAndroid Build Coastguard Worker x, 899*da0073e9SAndroid Build Coastguard Worker self.weight, 900*da0073e9SAndroid Build Coastguard Worker self.bias, 901*da0073e9SAndroid Build Coastguard Worker self.strides, 902*da0073e9SAndroid Build Coastguard Worker self.paddings, 903*da0073e9SAndroid Build Coastguard Worker self.dilations, 904*da0073e9SAndroid Build Coastguard Worker self.groups, 905*da0073e9SAndroid Build Coastguard Worker ) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker class Conv2DT(torch.nn.Module): 908*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 909*da0073e9SAndroid Build Coastguard Worker super().__init__() 910*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 911*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_transpose_weight_shape), requires_grad=False 912*da0073e9SAndroid Build Coastguard Worker ) 913*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 914*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_bias_shape), requires_grad=False 915*da0073e9SAndroid Build Coastguard Worker ) 916*da0073e9SAndroid Build Coastguard Worker self.strides = strides 917*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 918*da0073e9SAndroid Build Coastguard Worker self.output_paddings = output_paddings 919*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 920*da0073e9SAndroid Build Coastguard Worker self.groups = groups 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 923*da0073e9SAndroid Build Coastguard Worker return F.conv_transpose2d( 924*da0073e9SAndroid Build Coastguard Worker x, 925*da0073e9SAndroid Build Coastguard Worker self.weight, 926*da0073e9SAndroid Build Coastguard Worker self.bias, 927*da0073e9SAndroid Build Coastguard Worker self.strides, 928*da0073e9SAndroid Build Coastguard Worker self.paddings, 929*da0073e9SAndroid Build Coastguard Worker self.output_paddings, 930*da0073e9SAndroid Build Coastguard Worker self.groups, 931*da0073e9SAndroid Build Coastguard Worker self.dilations, 932*da0073e9SAndroid Build Coastguard Worker ) 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, height, width) 935*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 936*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": -1, 937*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": 1, 938*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 939*da0073e9SAndroid Build Coastguard Worker } 940*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 941*da0073e9SAndroid Build Coastguard Worker Conv2D(), pattern_count_map, data_shape 942*da0073e9SAndroid Build Coastguard Worker ) 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker transpose_data_shape = (batch_size, input_channels, height, width) 945*da0073e9SAndroid Build Coastguard Worker transpose_pattern_count_map = { 946*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv_transpose2d": -1, 947*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_transpose_clamp_prepack": 1, 948*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_transpose_clamp_run": 1, 949*da0073e9SAndroid Build Coastguard Worker } 950*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 951*da0073e9SAndroid Build Coastguard Worker Conv2DT(), transpose_pattern_count_map, data_shape 952*da0073e9SAndroid Build Coastguard Worker ) 953*da0073e9SAndroid Build Coastguard Worker 954*da0073e9SAndroid Build Coastguard Worker input_data = torch.rand((batch_size, input_channels, height, width)) 955*da0073e9SAndroid Build Coastguard Worker conv_weight = torch.rand( 956*da0073e9SAndroid Build Coastguard Worker (output_channels, input_channels_per_group, kernel_h, kernel_w) 957*da0073e9SAndroid Build Coastguard Worker ) 958*da0073e9SAndroid Build Coastguard Worker conv_bias = torch.rand(output_channels) 959*da0073e9SAndroid Build Coastguard Worker result = F.conv2d( 960*da0073e9SAndroid Build Coastguard Worker input_data, conv_weight, conv_bias, strides, paddings, dilations, groups 961*da0073e9SAndroid Build Coastguard Worker ) 962*da0073e9SAndroid Build Coastguard Worker linear_input_shape = result.shape[1] 963*da0073e9SAndroid Build Coastguard Worker linear_weight_shape = (weight_output_dim, linear_input_shape) 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 966*da0073e9SAndroid Build Coastguard Worker def __init__(self, activation_fn=F.relu): 967*da0073e9SAndroid Build Coastguard Worker super().__init__() 968*da0073e9SAndroid Build Coastguard Worker self.conv_weight = torch.nn.Parameter( 969*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_weight_shape), requires_grad=False 970*da0073e9SAndroid Build Coastguard Worker ) 971*da0073e9SAndroid Build Coastguard Worker self.conv_bias = torch.nn.Parameter( 972*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_bias_shape), requires_grad=False 973*da0073e9SAndroid Build Coastguard Worker ) 974*da0073e9SAndroid Build Coastguard Worker self.linear_weight = torch.nn.Parameter( 975*da0073e9SAndroid Build Coastguard Worker torch.rand(linear_weight_shape), requires_grad=False 976*da0073e9SAndroid Build Coastguard Worker ) 977*da0073e9SAndroid Build Coastguard Worker self.linear_bias = torch.nn.Parameter( 978*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 979*da0073e9SAndroid Build Coastguard Worker ) 980*da0073e9SAndroid Build Coastguard Worker self.strides = strides 981*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 982*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 983*da0073e9SAndroid Build Coastguard Worker self.groups = groups 984*da0073e9SAndroid Build Coastguard Worker self.activation_fn = activation_fn 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 987*da0073e9SAndroid Build Coastguard Worker o = F.conv2d( 988*da0073e9SAndroid Build Coastguard Worker x, 989*da0073e9SAndroid Build Coastguard Worker self.conv_weight, 990*da0073e9SAndroid Build Coastguard Worker self.conv_bias, 991*da0073e9SAndroid Build Coastguard Worker self.strides, 992*da0073e9SAndroid Build Coastguard Worker self.paddings, 993*da0073e9SAndroid Build Coastguard Worker self.dilations, 994*da0073e9SAndroid Build Coastguard Worker self.groups, 995*da0073e9SAndroid Build Coastguard Worker ) 996*da0073e9SAndroid Build Coastguard Worker o = self.activation_fn(o) 997*da0073e9SAndroid Build Coastguard Worker o = o.permute([0, 2, 3, 1]) 998*da0073e9SAndroid Build Coastguard Worker o = F.linear(o, self.linear_weight, self.linear_bias) 999*da0073e9SAndroid Build Coastguard Worker return self.activation_fn(o) 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1002*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": -1, 1003*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": 1, 1004*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1005*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": 1, 1006*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1007*da0073e9SAndroid Build Coastguard Worker } 1008*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1009*da0073e9SAndroid Build Coastguard Worker M(), pattern_count_map, data_shape 1010*da0073e9SAndroid Build Coastguard Worker ) 1011*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 1012*da0073e9SAndroid Build Coastguard Worker pattern_count_map["Tensor = prim::CallFunction"] = -1 1013*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::linear_clamp_prepack"] = -1 1014*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1015*da0073e9SAndroid Build Coastguard Worker M(), pattern_count_map, data_shape, prepack_removal=True 1016*da0073e9SAndroid Build Coastguard Worker ) 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker # Not inplace relu fusion test. 1019*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1020*da0073e9SAndroid Build Coastguard Worker "aten::relu": 2, 1021*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1022*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1023*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1024*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1025*da0073e9SAndroid Build Coastguard Worker } 1026*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1027*da0073e9SAndroid Build Coastguard Worker M(), pattern_count_map, data_shape, prepack_removal=True 1028*da0073e9SAndroid Build Coastguard Worker ) 1029*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 1030*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::linear_clamp_prepack"] = -1 1031*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = -1 1032*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1033*da0073e9SAndroid Build Coastguard Worker M(), 1034*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1035*da0073e9SAndroid Build Coastguard Worker data_shape, 1036*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1037*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1038*da0073e9SAndroid Build Coastguard Worker ) 1039*da0073e9SAndroid Build Coastguard Worker 1040*da0073e9SAndroid Build Coastguard Worker # Inplace relu fusion test. 1041*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1042*da0073e9SAndroid Build Coastguard Worker "aten::relu": 2, 1043*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1044*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1045*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1046*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1047*da0073e9SAndroid Build Coastguard Worker } 1048*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1049*da0073e9SAndroid Build Coastguard Worker M(F.relu_), pattern_count_map, data_shape, prepack_removal=True 1050*da0073e9SAndroid Build Coastguard Worker ) 1051*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 1052*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::linear_clamp_prepack"] = -1 1053*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::relu"] = -1 1054*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1055*da0073e9SAndroid Build Coastguard Worker M(F.relu_), 1056*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1057*da0073e9SAndroid Build Coastguard Worker data_shape, 1058*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1059*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1060*da0073e9SAndroid Build Coastguard Worker ) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker # Not inplace hardtanh fusion test. 1063*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1064*da0073e9SAndroid Build Coastguard Worker "aten::hardtanh": 2, 1065*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1066*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1067*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1068*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1069*da0073e9SAndroid Build Coastguard Worker } 1070*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1071*da0073e9SAndroid Build Coastguard Worker M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True 1072*da0073e9SAndroid Build Coastguard Worker ) 1073*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 1074*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::linear_clamp_prepack"] = -1 1075*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh"] = -1 1076*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1077*da0073e9SAndroid Build Coastguard Worker M(F.hardtanh), 1078*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1079*da0073e9SAndroid Build Coastguard Worker data_shape, 1080*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1081*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1082*da0073e9SAndroid Build Coastguard Worker ) 1083*da0073e9SAndroid Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker # Inplace hardtanh fusion test. 1085*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1086*da0073e9SAndroid Build Coastguard Worker "aten::hardtanh_": 2, 1087*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1088*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1089*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1090*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1091*da0073e9SAndroid Build Coastguard Worker } 1092*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1093*da0073e9SAndroid Build Coastguard Worker M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True 1094*da0073e9SAndroid Build Coastguard Worker ) 1095*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 1096*da0073e9SAndroid Build Coastguard Worker pattern_count_map["prepacked::linear_clamp_prepack"] = -1 1097*da0073e9SAndroid Build Coastguard Worker pattern_count_map["aten::hardtanh_"] = -1 1098*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1099*da0073e9SAndroid Build Coastguard Worker M(F.hardtanh_), 1100*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1101*da0073e9SAndroid Build Coastguard Worker data_shape, 1102*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1103*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1104*da0073e9SAndroid Build Coastguard Worker ) 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker class MFusionAntiPattern(torch.nn.Module): 1107*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1108*da0073e9SAndroid Build Coastguard Worker super().__init__() 1109*da0073e9SAndroid Build Coastguard Worker self.linear_weight = torch.nn.Parameter( 1110*da0073e9SAndroid Build Coastguard Worker torch.rand(linear_weight_shape), requires_grad=False 1111*da0073e9SAndroid Build Coastguard Worker ) 1112*da0073e9SAndroid Build Coastguard Worker self.linear_bias = torch.nn.Parameter( 1113*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 1114*da0073e9SAndroid Build Coastguard Worker ) 1115*da0073e9SAndroid Build Coastguard Worker self.strides = strides 1116*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 1117*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 1118*da0073e9SAndroid Build Coastguard Worker self.groups = groups 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1121*da0073e9SAndroid Build Coastguard Worker o = F.linear(x, self.linear_weight, self.linear_bias) 1122*da0073e9SAndroid Build Coastguard Worker o = F.relu(o) 1123*da0073e9SAndroid Build Coastguard Worker o = F.hardtanh(o) 1124*da0073e9SAndroid Build Coastguard Worker return o 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker # Unfusable hardtanh. 1127*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1128*da0073e9SAndroid Build Coastguard Worker "aten::hardtanh": 1, # hardtanh cannot be. 1129*da0073e9SAndroid Build Coastguard Worker "aten::relu": -1, # relu is fused. 1130*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1131*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1132*da0073e9SAndroid Build Coastguard Worker } 1133*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1134*da0073e9SAndroid Build Coastguard Worker MFusionAntiPattern(), 1135*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1136*da0073e9SAndroid Build Coastguard Worker (16, linear_weight_shape[1]), 1137*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1138*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1139*da0073e9SAndroid Build Coastguard Worker ) 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Worker class MFusionAntiPatternParamMinMax(torch.nn.Module): 1142*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1143*da0073e9SAndroid Build Coastguard Worker super().__init__() 1144*da0073e9SAndroid Build Coastguard Worker self.linear_weight = torch.nn.Parameter( 1145*da0073e9SAndroid Build Coastguard Worker torch.rand(linear_weight_shape), requires_grad=False 1146*da0073e9SAndroid Build Coastguard Worker ) 1147*da0073e9SAndroid Build Coastguard Worker self.linear_bias = torch.nn.Parameter( 1148*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 1149*da0073e9SAndroid Build Coastguard Worker ) 1150*da0073e9SAndroid Build Coastguard Worker self.strides = strides 1151*da0073e9SAndroid Build Coastguard Worker self.paddings = paddings 1152*da0073e9SAndroid Build Coastguard Worker self.dilations = dilations 1153*da0073e9SAndroid Build Coastguard Worker self.groups = groups 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1156*da0073e9SAndroid Build Coastguard Worker min = x[0, 0] 1157*da0073e9SAndroid Build Coastguard Worker max = min + 10 1158*da0073e9SAndroid Build Coastguard Worker o = F.linear(x, self.linear_weight, self.linear_bias) 1159*da0073e9SAndroid Build Coastguard Worker o = F.hardtanh(o, min, max) 1160*da0073e9SAndroid Build Coastguard Worker return o 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker # Unfusable hardtanh. 1163*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1164*da0073e9SAndroid Build Coastguard Worker "aten::hardtanh": 1, # hardtanh cannot be. 1165*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": -1, 1166*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1167*da0073e9SAndroid Build Coastguard Worker } 1168*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1169*da0073e9SAndroid Build Coastguard Worker MFusionAntiPatternParamMinMax(), 1170*da0073e9SAndroid Build Coastguard Worker pattern_count_map, 1171*da0073e9SAndroid Build Coastguard Worker (16, linear_weight_shape[1]), 1172*da0073e9SAndroid Build Coastguard Worker prepack_removal=True, 1173*da0073e9SAndroid Build Coastguard Worker fuse_clamping_ops=True, 1174*da0073e9SAndroid Build Coastguard Worker ) 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker def test_decomposed_linear(self): 1177*da0073e9SAndroid Build Coastguard Worker data_shape = [2, 32] 1178*da0073e9SAndroid Build Coastguard Worker weight_output_dim = 24 1179*da0073e9SAndroid Build Coastguard Worker weight_shape = (weight_output_dim, data_shape[-1]) 1180*da0073e9SAndroid Build Coastguard Worker 1181*da0073e9SAndroid Build Coastguard Worker class DecomposedLinearAddmm(torch.nn.Module): 1182*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1183*da0073e9SAndroid Build Coastguard Worker super().__init__() 1184*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 1185*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_shape), requires_grad=False 1186*da0073e9SAndroid Build Coastguard Worker ) 1187*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 1188*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 1189*da0073e9SAndroid Build Coastguard Worker ) 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1192*da0073e9SAndroid Build Coastguard Worker weight_t = self.weight.t() 1193*da0073e9SAndroid Build Coastguard Worker return torch.addmm(self.bias, x, weight_t) 1194*da0073e9SAndroid Build Coastguard Worker 1195*da0073e9SAndroid Build Coastguard Worker class DecomposedLinearMatmulAdd(torch.nn.Module): 1196*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1197*da0073e9SAndroid Build Coastguard Worker super().__init__() 1198*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 1199*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_shape), requires_grad=False 1200*da0073e9SAndroid Build Coastguard Worker ) 1201*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 1202*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 1203*da0073e9SAndroid Build Coastguard Worker ) 1204*da0073e9SAndroid Build Coastguard Worker 1205*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1206*da0073e9SAndroid Build Coastguard Worker weight_t = self.weight.t() 1207*da0073e9SAndroid Build Coastguard Worker y = torch.matmul(x, weight_t) 1208*da0073e9SAndroid Build Coastguard Worker res = y.add_(self.bias) 1209*da0073e9SAndroid Build Coastguard Worker return res 1210*da0073e9SAndroid Build Coastguard Worker 1211*da0073e9SAndroid Build Coastguard Worker class DecomposedLinearMatmul(torch.nn.Module): 1212*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1213*da0073e9SAndroid Build Coastguard Worker super().__init__() 1214*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 1215*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_shape), requires_grad=False 1216*da0073e9SAndroid Build Coastguard Worker ) 1217*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 1218*da0073e9SAndroid Build Coastguard Worker torch.rand(weight_output_dim), requires_grad=False 1219*da0073e9SAndroid Build Coastguard Worker ) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1222*da0073e9SAndroid Build Coastguard Worker weight_t = self.weight.t() 1223*da0073e9SAndroid Build Coastguard Worker res = torch.matmul(x, weight_t) 1224*da0073e9SAndroid Build Coastguard Worker return res 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker # Linear with bias pattern. 1227*da0073e9SAndroid Build Coastguard Worker pattern_count_map = { 1228*da0073e9SAndroid Build Coastguard Worker "Tensor = prim::CallFunction": -1, 1229*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_prepack": 1, 1230*da0073e9SAndroid Build Coastguard Worker "prepacked::linear_clamp_run": 1, 1231*da0073e9SAndroid Build Coastguard Worker } 1232*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1233*da0073e9SAndroid Build Coastguard Worker DecomposedLinearAddmm(), pattern_count_map, data_shape 1234*da0073e9SAndroid Build Coastguard Worker ) 1235*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1236*da0073e9SAndroid Build Coastguard Worker DecomposedLinearMatmulAdd(), pattern_count_map, data_shape 1237*da0073e9SAndroid Build Coastguard Worker ) 1238*da0073e9SAndroid Build Coastguard Worker TestXNNPACKRewritePass.validate_transformed_module( 1239*da0073e9SAndroid Build Coastguard Worker DecomposedLinearMatmul(), pattern_count_map, data_shape 1240*da0073e9SAndroid Build Coastguard Worker ) 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless( 1244*da0073e9SAndroid Build Coastguard Worker torch.backends.xnnpack.enabled, 1245*da0073e9SAndroid Build Coastguard Worker " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", 1246*da0073e9SAndroid Build Coastguard Worker) 1247*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf( 1248*da0073e9SAndroid Build Coastguard Worker TEST_WITH_TSAN, 1249*da0073e9SAndroid Build Coastguard Worker "TSAN is not fork-safe since we're forking in a multi-threaded environment", 1250*da0073e9SAndroid Build Coastguard Worker) 1251*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKConv1dTransformPass(TestCase): 1252*da0073e9SAndroid Build Coastguard Worker @staticmethod 1253*da0073e9SAndroid Build Coastguard Worker def validate_transform_conv1d_to_conv2d( 1254*da0073e9SAndroid Build Coastguard Worker self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape 1255*da0073e9SAndroid Build Coastguard Worker ): 1256*da0073e9SAndroid Build Coastguard Worker input_data = torch.normal(1, 20, size=data_shape) 1257*da0073e9SAndroid Build Coastguard Worker 1258*da0073e9SAndroid Build Coastguard Worker for jit_method in ["script", "trace"]: 1259*da0073e9SAndroid Build Coastguard Worker module_instance = self 1260*da0073e9SAndroid Build Coastguard Worker if jit_method == "script": 1261*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(module_instance) 1262*da0073e9SAndroid Build Coastguard Worker else: 1263*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.trace(module_instance, input_data) 1264*da0073e9SAndroid Build Coastguard Worker scripted_model.eval() 1265*da0073e9SAndroid Build Coastguard Worker ref_result = scripted_model(input_data) 1266*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c) 1267*da0073e9SAndroid Build Coastguard Worker optimized_scripted_model = optimize_for_mobile(scripted_model) 1268*da0073e9SAndroid Build Coastguard Worker 1269*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 1270*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted_model, buffer) 1271*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 1272*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model = torch.jit.load(buffer) 1273*da0073e9SAndroid Build Coastguard Worker 1274*da0073e9SAndroid Build Coastguard Worker for pattern, v in pattern_count_transformed_map.items(): 1275*da0073e9SAndroid Build Coastguard Worker if v == 0: 1276*da0073e9SAndroid Build Coastguard Worker FileCheck().check(pattern).run(deserialized_scripted_model.graph) 1277*da0073e9SAndroid Build Coastguard Worker elif v == -1: 1278*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not(pattern).run( 1279*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model.graph 1280*da0073e9SAndroid Build Coastguard Worker ) 1281*da0073e9SAndroid Build Coastguard Worker else: 1282*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(pattern, v, exactly=True).run( 1283*da0073e9SAndroid Build Coastguard Worker deserialized_scripted_model.graph 1284*da0073e9SAndroid Build Coastguard Worker ) 1285*da0073e9SAndroid Build Coastguard Worker transformed_result = deserialized_scripted_model(input_data) 1286*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 1287*da0073e9SAndroid Build Coastguard Worker ref_result, transformed_result, rtol=1e-2, atol=1e-3 1288*da0073e9SAndroid Build Coastguard Worker ) 1289*da0073e9SAndroid Build Coastguard Worker 1290*da0073e9SAndroid Build Coastguard Worker optimized_buffer = io.BytesIO() 1291*da0073e9SAndroid Build Coastguard Worker torch.jit.save(optimized_scripted_model, optimized_buffer) 1292*da0073e9SAndroid Build Coastguard Worker optimized_buffer.seek(0) 1293*da0073e9SAndroid Build Coastguard Worker deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer) 1294*da0073e9SAndroid Build Coastguard Worker 1295*da0073e9SAndroid Build Coastguard Worker for pattern, v in pattern_count_optimized_map.items(): 1296*da0073e9SAndroid Build Coastguard Worker if v == 0: 1297*da0073e9SAndroid Build Coastguard Worker FileCheck().check(pattern).run( 1298*da0073e9SAndroid Build Coastguard Worker deserialized_optimized_scripted_model.graph 1299*da0073e9SAndroid Build Coastguard Worker ) 1300*da0073e9SAndroid Build Coastguard Worker elif v == -1: 1301*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not(pattern).run( 1302*da0073e9SAndroid Build Coastguard Worker deserialized_optimized_scripted_model.graph 1303*da0073e9SAndroid Build Coastguard Worker ) 1304*da0073e9SAndroid Build Coastguard Worker else: 1305*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(pattern, v, exactly=True).run( 1306*da0073e9SAndroid Build Coastguard Worker deserialized_optimized_scripted_model.graph 1307*da0073e9SAndroid Build Coastguard Worker ) 1308*da0073e9SAndroid Build Coastguard Worker xnnpack_result = deserialized_optimized_scripted_model(input_data) 1309*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) 1310*da0073e9SAndroid Build Coastguard Worker 1311*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE, "T137513244") 1312*da0073e9SAndroid Build Coastguard Worker def test_conv1d_basic(self): 1313*da0073e9SAndroid Build Coastguard Worker batch_size_list = range(1, 3) 1314*da0073e9SAndroid Build Coastguard Worker input_channels_per_group_list = range(10, 12) 1315*da0073e9SAndroid Build Coastguard Worker width_list = range(10, 12) 1316*da0073e9SAndroid Build Coastguard Worker output_channels_per_group_list = range(10, 12) 1317*da0073e9SAndroid Build Coastguard Worker groups_list = range(1, 3) 1318*da0073e9SAndroid Build Coastguard Worker kernel_list = range(1, 4) 1319*da0073e9SAndroid Build Coastguard Worker stride_list = range(1, 3) 1320*da0073e9SAndroid Build Coastguard Worker padding_list = range(0, 3) 1321*da0073e9SAndroid Build Coastguard Worker dilation_list = range(1, 3) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker for hparams in itertools.product( 1324*da0073e9SAndroid Build Coastguard Worker batch_size_list, 1325*da0073e9SAndroid Build Coastguard Worker input_channels_per_group_list, 1326*da0073e9SAndroid Build Coastguard Worker width_list, 1327*da0073e9SAndroid Build Coastguard Worker output_channels_per_group_list, 1328*da0073e9SAndroid Build Coastguard Worker groups_list, 1329*da0073e9SAndroid Build Coastguard Worker kernel_list, 1330*da0073e9SAndroid Build Coastguard Worker stride_list, 1331*da0073e9SAndroid Build Coastguard Worker padding_list, 1332*da0073e9SAndroid Build Coastguard Worker dilation_list, 1333*da0073e9SAndroid Build Coastguard Worker ): 1334*da0073e9SAndroid Build Coastguard Worker ( 1335*da0073e9SAndroid Build Coastguard Worker batch_size, 1336*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 1337*da0073e9SAndroid Build Coastguard Worker width, 1338*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 1339*da0073e9SAndroid Build Coastguard Worker groups, 1340*da0073e9SAndroid Build Coastguard Worker kernel, 1341*da0073e9SAndroid Build Coastguard Worker stride, 1342*da0073e9SAndroid Build Coastguard Worker padding, 1343*da0073e9SAndroid Build Coastguard Worker dilation, 1344*da0073e9SAndroid Build Coastguard Worker ) = hparams 1345*da0073e9SAndroid Build Coastguard Worker 1346*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 1347*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 1348*da0073e9SAndroid Build Coastguard Worker conv_weight_shape = (output_channels, input_channels_per_group, kernel) 1349*da0073e9SAndroid Build Coastguard Worker conv_bias_shape = output_channels 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker class Conv1D(torch.nn.Module): 1352*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1353*da0073e9SAndroid Build Coastguard Worker super().__init__() 1354*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter( 1355*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_weight_shape), requires_grad=False 1356*da0073e9SAndroid Build Coastguard Worker ) 1357*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter( 1358*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_bias_shape), requires_grad=False 1359*da0073e9SAndroid Build Coastguard Worker ) 1360*da0073e9SAndroid Build Coastguard Worker self.stride = stride 1361*da0073e9SAndroid Build Coastguard Worker self.padding = padding 1362*da0073e9SAndroid Build Coastguard Worker self.dilation = dilation 1363*da0073e9SAndroid Build Coastguard Worker self.groups = groups 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1366*da0073e9SAndroid Build Coastguard Worker return F.conv1d( 1367*da0073e9SAndroid Build Coastguard Worker x, 1368*da0073e9SAndroid Build Coastguard Worker self.weight, 1369*da0073e9SAndroid Build Coastguard Worker self.bias, 1370*da0073e9SAndroid Build Coastguard Worker self.stride, 1371*da0073e9SAndroid Build Coastguard Worker self.padding, 1372*da0073e9SAndroid Build Coastguard Worker self.dilation, 1373*da0073e9SAndroid Build Coastguard Worker self.groups, 1374*da0073e9SAndroid Build Coastguard Worker ) 1375*da0073e9SAndroid Build Coastguard Worker 1376*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, width) 1377*da0073e9SAndroid Build Coastguard Worker pattern_count_transformed_map = { 1378*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv1d": -1, 1379*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": 1, 1380*da0073e9SAndroid Build Coastguard Worker } 1381*da0073e9SAndroid Build Coastguard Worker pattern_count_optimized_map = { 1382*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv1d": -1, 1383*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": -1, 1384*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1385*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1386*da0073e9SAndroid Build Coastguard Worker } 1387*da0073e9SAndroid Build Coastguard Worker 1388*da0073e9SAndroid Build Coastguard Worker TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d( 1389*da0073e9SAndroid Build Coastguard Worker Conv1D(), 1390*da0073e9SAndroid Build Coastguard Worker pattern_count_transformed_map, 1391*da0073e9SAndroid Build Coastguard Worker pattern_count_optimized_map, 1392*da0073e9SAndroid Build Coastguard Worker data_shape, 1393*da0073e9SAndroid Build Coastguard Worker ) 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/46066 1396*da0073e9SAndroid Build Coastguard Worker @slowTest 1397*da0073e9SAndroid Build Coastguard Worker def test_conv1d_with_relu_fc(self): 1398*da0073e9SAndroid Build Coastguard Worker batch_size_list = range(1, 3) 1399*da0073e9SAndroid Build Coastguard Worker input_channels_per_group_list = range(10, 12) 1400*da0073e9SAndroid Build Coastguard Worker width_list = range(10, 12) 1401*da0073e9SAndroid Build Coastguard Worker output_channels_per_group_list = range(10, 12) 1402*da0073e9SAndroid Build Coastguard Worker groups_list = range(1, 3) 1403*da0073e9SAndroid Build Coastguard Worker kernel_list = range(1, 4) 1404*da0073e9SAndroid Build Coastguard Worker stride_list = range(1, 3) 1405*da0073e9SAndroid Build Coastguard Worker padding_list = range(0, 3) 1406*da0073e9SAndroid Build Coastguard Worker dilation_list = range(1, 3) 1407*da0073e9SAndroid Build Coastguard Worker output_features_list = range(1, 3) 1408*da0073e9SAndroid Build Coastguard Worker 1409*da0073e9SAndroid Build Coastguard Worker for hparams in itertools.product( 1410*da0073e9SAndroid Build Coastguard Worker batch_size_list, 1411*da0073e9SAndroid Build Coastguard Worker input_channels_per_group_list, 1412*da0073e9SAndroid Build Coastguard Worker width_list, 1413*da0073e9SAndroid Build Coastguard Worker output_channels_per_group_list, 1414*da0073e9SAndroid Build Coastguard Worker groups_list, 1415*da0073e9SAndroid Build Coastguard Worker kernel_list, 1416*da0073e9SAndroid Build Coastguard Worker stride_list, 1417*da0073e9SAndroid Build Coastguard Worker padding_list, 1418*da0073e9SAndroid Build Coastguard Worker dilation_list, 1419*da0073e9SAndroid Build Coastguard Worker output_features_list, 1420*da0073e9SAndroid Build Coastguard Worker ): 1421*da0073e9SAndroid Build Coastguard Worker ( 1422*da0073e9SAndroid Build Coastguard Worker batch_size, 1423*da0073e9SAndroid Build Coastguard Worker input_channels_per_group, 1424*da0073e9SAndroid Build Coastguard Worker width, 1425*da0073e9SAndroid Build Coastguard Worker output_channels_per_group, 1426*da0073e9SAndroid Build Coastguard Worker groups, 1427*da0073e9SAndroid Build Coastguard Worker kernel, 1428*da0073e9SAndroid Build Coastguard Worker stride, 1429*da0073e9SAndroid Build Coastguard Worker padding, 1430*da0073e9SAndroid Build Coastguard Worker dilation, 1431*da0073e9SAndroid Build Coastguard Worker output_features, 1432*da0073e9SAndroid Build Coastguard Worker ) = hparams 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker input_channels = input_channels_per_group * groups 1435*da0073e9SAndroid Build Coastguard Worker output_channels = output_channels_per_group * groups 1436*da0073e9SAndroid Build Coastguard Worker conv_weight_shape = (output_channels, input_channels_per_group, kernel) 1437*da0073e9SAndroid Build Coastguard Worker conv_bias_shape = output_channels 1438*da0073e9SAndroid Build Coastguard Worker conv_output_width = ( 1439*da0073e9SAndroid Build Coastguard Worker int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1 1440*da0073e9SAndroid Build Coastguard Worker ) 1441*da0073e9SAndroid Build Coastguard Worker fc_weight_shape = (output_features, output_channels * conv_output_width) 1442*da0073e9SAndroid Build Coastguard Worker fc_bias_shape = output_features 1443*da0073e9SAndroid Build Coastguard Worker 1444*da0073e9SAndroid Build Coastguard Worker class Net(torch.nn.Module): 1445*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1446*da0073e9SAndroid Build Coastguard Worker super().__init__() 1447*da0073e9SAndroid Build Coastguard Worker self.conv_weight = torch.nn.Parameter( 1448*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_weight_shape), requires_grad=False 1449*da0073e9SAndroid Build Coastguard Worker ) 1450*da0073e9SAndroid Build Coastguard Worker self.conv_bias = torch.nn.Parameter( 1451*da0073e9SAndroid Build Coastguard Worker torch.rand(conv_bias_shape), requires_grad=False 1452*da0073e9SAndroid Build Coastguard Worker ) 1453*da0073e9SAndroid Build Coastguard Worker self.stride = stride 1454*da0073e9SAndroid Build Coastguard Worker self.padding = padding 1455*da0073e9SAndroid Build Coastguard Worker self.dilation = dilation 1456*da0073e9SAndroid Build Coastguard Worker self.groups = groups 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker self.fc_weight = torch.nn.Parameter( 1459*da0073e9SAndroid Build Coastguard Worker torch.rand(fc_weight_shape), requires_grad=False 1460*da0073e9SAndroid Build Coastguard Worker ) 1461*da0073e9SAndroid Build Coastguard Worker self.fc_bias = torch.nn.Parameter( 1462*da0073e9SAndroid Build Coastguard Worker torch.rand(fc_bias_shape), requires_grad=False 1463*da0073e9SAndroid Build Coastguard Worker ) 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1466*da0073e9SAndroid Build Coastguard Worker x = F.conv1d( 1467*da0073e9SAndroid Build Coastguard Worker x, 1468*da0073e9SAndroid Build Coastguard Worker self.conv_weight, 1469*da0073e9SAndroid Build Coastguard Worker self.conv_bias, 1470*da0073e9SAndroid Build Coastguard Worker self.stride, 1471*da0073e9SAndroid Build Coastguard Worker self.padding, 1472*da0073e9SAndroid Build Coastguard Worker self.dilation, 1473*da0073e9SAndroid Build Coastguard Worker self.groups, 1474*da0073e9SAndroid Build Coastguard Worker ) 1475*da0073e9SAndroid Build Coastguard Worker x = F.relu(x) 1476*da0073e9SAndroid Build Coastguard Worker x = x.view(x.size(0), -1) 1477*da0073e9SAndroid Build Coastguard Worker x = F.linear(x, self.fc_weight, self.fc_bias) 1478*da0073e9SAndroid Build Coastguard Worker return x 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker data_shape = (batch_size, input_channels, width) 1481*da0073e9SAndroid Build Coastguard Worker pattern_count_transformed_map = { 1482*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv1d": -1, 1483*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": 1, 1484*da0073e9SAndroid Build Coastguard Worker } 1485*da0073e9SAndroid Build Coastguard Worker pattern_count_optimized_map = { 1486*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv1d": -1, 1487*da0073e9SAndroid Build Coastguard Worker "Tensor = aten::conv2d": -1, 1488*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_prepack": -1, 1489*da0073e9SAndroid Build Coastguard Worker "prepacked::conv2d_clamp_run": 1, 1490*da0073e9SAndroid Build Coastguard Worker } 1491*da0073e9SAndroid Build Coastguard Worker TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d( 1492*da0073e9SAndroid Build Coastguard Worker Net(), 1493*da0073e9SAndroid Build Coastguard Worker pattern_count_transformed_map, 1494*da0073e9SAndroid Build Coastguard Worker pattern_count_optimized_map, 1495*da0073e9SAndroid Build Coastguard Worker data_shape, 1496*da0073e9SAndroid Build Coastguard Worker ) 1497*da0073e9SAndroid Build Coastguard Worker 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1500*da0073e9SAndroid Build Coastguard Worker run_tests() 1501