xref: /aosp_15_r20/external/pytorch/test/test_xnnpack_integration.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport itertools
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom hypothesis import assume, given, strategies as st
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.backends.xnnpack
11*da0073e9SAndroid Build Coastguard Workerimport torch.testing._internal.hypothesis_utils as hu
12*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
15*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
16*da0073e9SAndroid Build Coastguard Worker    run_tests,
17*da0073e9SAndroid Build Coastguard Worker    slowTest,
18*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
19*da0073e9SAndroid Build Coastguard Worker    TestCase,
20*da0073e9SAndroid Build Coastguard Worker)
21*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.mobile_optimizer import optimize_for_mobile
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
)
class TestXNNPACKOps(TestCase):
    """Compare the XNNPACK prepacked operators (``torch.ops.prepacked.*``)
    against the reference ``torch.nn.functional`` implementations on
    hypothesis-generated inputs.

    Each test builds random inputs/weights, runs the reference op, then runs
    the prepack + clamp_run pair and asserts the results agree to loose
    (rtol=1e-2, atol=1e-3) tolerances, since XNNPACK may use lower-precision
    accumulation than the reference path.
    """

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        """linear_clamp_prepack/linear_clamp_run vs. F.linear on N-D input.

        batch_size may be 0, exercising the empty-batch edge case.
        """
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        input_size=st.integers(2, 32),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        """Same as test_linear but with a 1-D (unbatched) input tensor."""
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        bias = torch.rand(weight_output_dim) if use_bias else None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        """conv2d_clamp_prepack/conv2d_clamp_run vs. F.conv2d across grouped
        convs, assorted strides/paddings/dilations and memory formats.
        """
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        # Reject parameter combinations whose effective kernel does not fit
        # inside the padded input (conv output would be empty/invalid).
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(output_channels) if use_bias else None

        ref_result = F.conv2d(
            input_data, weight, bias, strides, paddings, dilations, groups
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
            weight, bias, strides, paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(1, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        """conv2d_transpose_clamp_prepack/run vs. F.conv_transpose2d.

        Transposed-conv weight layout is (in_channels, out_channels/groups,
        kH, kW), and output padding must be smaller than both stride and
        dilation for the op to be well defined.
        """
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = torch.rand(output_channels) if use_bias else None

        # Note that groups/dilation is in reverse order from conv2d
        ref_result = F.conv_transpose2d(
            input_data,
            weight,
            bias,
            strides,
            paddings,
            output_paddings,
            groups,
            dilation,
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
            weight, bias, strides, paddings, output_paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(
            input_data, packed_weight_bias
        )
        # Compare contiguous copies so differing memory formats don't matter.
        torch.testing.assert_close(
            ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3
        )
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless(
233*da0073e9SAndroid Build Coastguard Worker    torch.backends.xnnpack.enabled,
234*da0073e9SAndroid Build Coastguard Worker    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
235*da0073e9SAndroid Build Coastguard Worker)
236*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
237*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
238*da0073e9SAndroid Build Coastguard Worker    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
239*da0073e9SAndroid Build Coastguard Worker)
240*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKSerDes(TestCase):
241*da0073e9SAndroid Build Coastguard Worker    @unittest.skip(
242*da0073e9SAndroid Build Coastguard Worker        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
243*da0073e9SAndroid Build Coastguard Worker    )
244*da0073e9SAndroid Build Coastguard Worker    @given(
245*da0073e9SAndroid Build Coastguard Worker        batch_size=st.integers(0, 3),
246*da0073e9SAndroid Build Coastguard Worker        data_shape=hu.array_shapes(1, 3, 2, 64),
247*da0073e9SAndroid Build Coastguard Worker        weight_output_dim=st.integers(2, 64),
248*da0073e9SAndroid Build Coastguard Worker        use_bias=st.booleans(),
249*da0073e9SAndroid Build Coastguard Worker    )
250*da0073e9SAndroid Build Coastguard Worker    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
251*da0073e9SAndroid Build Coastguard Worker        class Linear(torch.nn.Module):
252*da0073e9SAndroid Build Coastguard Worker            def __init__(self, weight, bias=None):
253*da0073e9SAndroid Build Coastguard Worker                super().__init__()
254*da0073e9SAndroid Build Coastguard Worker                self.weight = weight
255*da0073e9SAndroid Build Coastguard Worker                self.bias = bias
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
258*da0073e9SAndroid Build Coastguard Worker                return F.linear(x, self.weight, self.bias)
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        class LinearPrePacked(torch.nn.Module):
261*da0073e9SAndroid Build Coastguard Worker            def __init__(self, weight, bias=None):
262*da0073e9SAndroid Build Coastguard Worker                super().__init__()
263*da0073e9SAndroid Build Coastguard Worker                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(
264*da0073e9SAndroid Build Coastguard Worker                    weight, bias
265*da0073e9SAndroid Build Coastguard Worker                )
266*da0073e9SAndroid Build Coastguard Worker
267*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
268*da0073e9SAndroid Build Coastguard Worker                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker        data_shape = [batch_size] + list(data_shape)
271*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand((weight_output_dim, data_shape[-1]))
272*da0073e9SAndroid Build Coastguard Worker        if use_bias:
273*da0073e9SAndroid Build Coastguard Worker            bias = torch.rand(weight_output_dim)
274*da0073e9SAndroid Build Coastguard Worker        else:
275*da0073e9SAndroid Build Coastguard Worker            bias = None
276*da0073e9SAndroid Build Coastguard Worker        scripted_linear = torch.jit.script(Linear(weight, bias))
277*da0073e9SAndroid Build Coastguard Worker        scripted_linear_clamp_prepacked = torch.jit.script(
278*da0073e9SAndroid Build Coastguard Worker            LinearPrePacked(weight, bias)
279*da0073e9SAndroid Build Coastguard Worker        )
280*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand(data_shape)
281*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_linear(input_data)
282*da0073e9SAndroid Build Coastguard Worker        output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
283*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
284*da0073e9SAndroid Build Coastguard Worker            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
285*da0073e9SAndroid Build Coastguard Worker        )
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker        # Serialize the modules and then deserialize
288*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand(data_shape)
289*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
290*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_linear, buffer)
291*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
292*da0073e9SAndroid Build Coastguard Worker        deserialized_linear = torch.jit.load(buffer)
293*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
294*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_linear_clamp_prepacked, buffer)
295*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
296*da0073e9SAndroid Build Coastguard Worker        deserialized_linear_clamp_prepacked = torch.jit.load(buffer)
297*da0073e9SAndroid Build Coastguard Worker        ref_result = deserialized_linear(input_data)
298*da0073e9SAndroid Build Coastguard Worker        output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
299*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(
300*da0073e9SAndroid Build Coastguard Worker            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
301*da0073e9SAndroid Build Coastguard Worker        )
302*da0073e9SAndroid Build Coastguard Worker
303*da0073e9SAndroid Build Coastguard Worker    @given(
304*da0073e9SAndroid Build Coastguard Worker        batch_size=st.integers(0, 3),
305*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group=st.integers(1, 32),
306*da0073e9SAndroid Build Coastguard Worker        height=st.integers(5, 64),
307*da0073e9SAndroid Build Coastguard Worker        width=st.integers(5, 64),
308*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group=st.integers(1, 32),
309*da0073e9SAndroid Build Coastguard Worker        groups=st.integers(1, 16),
310*da0073e9SAndroid Build Coastguard Worker        kernel_h=st.integers(1, 7),
311*da0073e9SAndroid Build Coastguard Worker        kernel_w=st.integers(1, 7),
312*da0073e9SAndroid Build Coastguard Worker        stride_h=st.integers(1, 2),
313*da0073e9SAndroid Build Coastguard Worker        stride_w=st.integers(1, 2),
314*da0073e9SAndroid Build Coastguard Worker        pad_h=st.integers(0, 2),
315*da0073e9SAndroid Build Coastguard Worker        pad_w=st.integers(0, 2),
316*da0073e9SAndroid Build Coastguard Worker        dilation=st.integers(1, 2),
317*da0073e9SAndroid Build Coastguard Worker        use_bias=st.booleans(),
318*da0073e9SAndroid Build Coastguard Worker        format=st.sampled_from(
319*da0073e9SAndroid Build Coastguard Worker            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
320*da0073e9SAndroid Build Coastguard Worker        ),
321*da0073e9SAndroid Build Coastguard Worker    )
322*da0073e9SAndroid Build Coastguard Worker    def test_conv2d(
323*da0073e9SAndroid Build Coastguard Worker        self,
324*da0073e9SAndroid Build Coastguard Worker        batch_size,
325*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group,
326*da0073e9SAndroid Build Coastguard Worker        height,
327*da0073e9SAndroid Build Coastguard Worker        width,
328*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group,
329*da0073e9SAndroid Build Coastguard Worker        groups,
330*da0073e9SAndroid Build Coastguard Worker        kernel_h,
331*da0073e9SAndroid Build Coastguard Worker        kernel_w,
332*da0073e9SAndroid Build Coastguard Worker        stride_h,
333*da0073e9SAndroid Build Coastguard Worker        stride_w,
334*da0073e9SAndroid Build Coastguard Worker        pad_h,
335*da0073e9SAndroid Build Coastguard Worker        pad_w,
336*da0073e9SAndroid Build Coastguard Worker        dilation,
337*da0073e9SAndroid Build Coastguard Worker        use_bias,
338*da0073e9SAndroid Build Coastguard Worker        format,
339*da0073e9SAndroid Build Coastguard Worker    ):
340*da0073e9SAndroid Build Coastguard Worker        class Conv2D(torch.nn.Module):
341*da0073e9SAndroid Build Coastguard Worker            def __init__(self, weight, bias, strides, paddings, dilations, groups):
342*da0073e9SAndroid Build Coastguard Worker                super().__init__()
343*da0073e9SAndroid Build Coastguard Worker                self.weight = weight
344*da0073e9SAndroid Build Coastguard Worker                self.bias = bias
345*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
346*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
347*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
348*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
351*da0073e9SAndroid Build Coastguard Worker                return F.conv2d(
352*da0073e9SAndroid Build Coastguard Worker                    x,
353*da0073e9SAndroid Build Coastguard Worker                    self.weight,
354*da0073e9SAndroid Build Coastguard Worker                    self.bias,
355*da0073e9SAndroid Build Coastguard Worker                    self.strides,
356*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
357*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
358*da0073e9SAndroid Build Coastguard Worker                    self.groups,
359*da0073e9SAndroid Build Coastguard Worker                )
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker        class Conv2DPrePacked(torch.nn.Module):
362*da0073e9SAndroid Build Coastguard Worker            def __init__(self, weight, bias, strides, paddings, dilations, groups):
363*da0073e9SAndroid Build Coastguard Worker                super().__init__()
364*da0073e9SAndroid Build Coastguard Worker                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
365*da0073e9SAndroid Build Coastguard Worker                    weight, bias, strides, paddings, dilations, groups
366*da0073e9SAndroid Build Coastguard Worker                )
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
369*da0073e9SAndroid Build Coastguard Worker                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
372*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
373*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
374*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
375*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
376*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
377*da0073e9SAndroid Build Coastguard Worker        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
378*da0073e9SAndroid Build Coastguard Worker        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
381*da0073e9SAndroid Build Coastguard Worker        if format is not None:
382*da0073e9SAndroid Build Coastguard Worker            input_data = input_data.contiguous(memory_format=format)
383*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(
384*da0073e9SAndroid Build Coastguard Worker            (output_channels, input_channels_per_group, kernel_h, kernel_w)
385*da0073e9SAndroid Build Coastguard Worker        )
386*da0073e9SAndroid Build Coastguard Worker        bias = None
387*da0073e9SAndroid Build Coastguard Worker        if use_bias:
388*da0073e9SAndroid Build Coastguard Worker            bias = torch.rand(output_channels)
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker        scripted_conv2d = torch.jit.script(
391*da0073e9SAndroid Build Coastguard Worker            Conv2D(weight, bias, strides, paddings, dilations, groups)
392*da0073e9SAndroid Build Coastguard Worker        )
393*da0073e9SAndroid Build Coastguard Worker        scripted_conv2d_clamp_prepacked = torch.jit.script(
394*da0073e9SAndroid Build Coastguard Worker            Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups)
395*da0073e9SAndroid Build Coastguard Worker        )
396*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_conv2d(input_data)
397*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
398*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker        # Serialize the modules and then deserialize
401*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
402*da0073e9SAndroid Build Coastguard Worker        if format is not None:
403*da0073e9SAndroid Build Coastguard Worker            input_data = input_data.contiguous(memory_format=format)
404*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
405*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_conv2d, buffer)
406*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
407*da0073e9SAndroid Build Coastguard Worker        deserialized_conv2d = torch.jit.load(buffer)
408*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
409*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
410*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
411*da0073e9SAndroid Build Coastguard Worker        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
412*da0073e9SAndroid Build Coastguard Worker        ref_result = deserialized_conv2d(input_data)
413*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
414*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker    @given(
417*da0073e9SAndroid Build Coastguard Worker        batch_size=st.integers(0, 3),
418*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group=st.integers(1, 32),
419*da0073e9SAndroid Build Coastguard Worker        height=st.integers(5, 64),
420*da0073e9SAndroid Build Coastguard Worker        width=st.integers(5, 64),
421*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group=st.integers(1, 32),
422*da0073e9SAndroid Build Coastguard Worker        groups=st.integers(1, 16),
423*da0073e9SAndroid Build Coastguard Worker        kernel_h=st.integers(1, 7),
424*da0073e9SAndroid Build Coastguard Worker        kernel_w=st.integers(1, 7),
425*da0073e9SAndroid Build Coastguard Worker        stride_h=st.integers(1, 2),
426*da0073e9SAndroid Build Coastguard Worker        stride_w=st.integers(1, 2),
427*da0073e9SAndroid Build Coastguard Worker        pad_h=st.integers(0, 2),
428*da0073e9SAndroid Build Coastguard Worker        pad_w=st.integers(0, 2),
429*da0073e9SAndroid Build Coastguard Worker        output_pad_h=st.integers(0, 2),
430*da0073e9SAndroid Build Coastguard Worker        output_pad_w=st.integers(0, 2),
431*da0073e9SAndroid Build Coastguard Worker        dilation=st.integers(1, 2),
432*da0073e9SAndroid Build Coastguard Worker        use_bias=st.booleans(),
433*da0073e9SAndroid Build Coastguard Worker        format=st.sampled_from(
434*da0073e9SAndroid Build Coastguard Worker            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
435*da0073e9SAndroid Build Coastguard Worker        ),
436*da0073e9SAndroid Build Coastguard Worker    )
437*da0073e9SAndroid Build Coastguard Worker    def test_conv2d_transpose(
438*da0073e9SAndroid Build Coastguard Worker        self,
439*da0073e9SAndroid Build Coastguard Worker        batch_size,
440*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group,
441*da0073e9SAndroid Build Coastguard Worker        height,
442*da0073e9SAndroid Build Coastguard Worker        width,
443*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group,
444*da0073e9SAndroid Build Coastguard Worker        groups,
445*da0073e9SAndroid Build Coastguard Worker        kernel_h,
446*da0073e9SAndroid Build Coastguard Worker        kernel_w,
447*da0073e9SAndroid Build Coastguard Worker        stride_h,
448*da0073e9SAndroid Build Coastguard Worker        stride_w,
449*da0073e9SAndroid Build Coastguard Worker        pad_h,
450*da0073e9SAndroid Build Coastguard Worker        pad_w,
451*da0073e9SAndroid Build Coastguard Worker        output_pad_h,
452*da0073e9SAndroid Build Coastguard Worker        output_pad_w,
453*da0073e9SAndroid Build Coastguard Worker        dilation,
454*da0073e9SAndroid Build Coastguard Worker        use_bias,
455*da0073e9SAndroid Build Coastguard Worker        format,
456*da0073e9SAndroid Build Coastguard Worker    ):
457*da0073e9SAndroid Build Coastguard Worker        class Conv2DT(torch.nn.Module):
458*da0073e9SAndroid Build Coastguard Worker            def __init__(
459*da0073e9SAndroid Build Coastguard Worker                self,
460*da0073e9SAndroid Build Coastguard Worker                weight,
461*da0073e9SAndroid Build Coastguard Worker                bias,
462*da0073e9SAndroid Build Coastguard Worker                strides,
463*da0073e9SAndroid Build Coastguard Worker                paddings,
464*da0073e9SAndroid Build Coastguard Worker                output_paddings,
465*da0073e9SAndroid Build Coastguard Worker                dilations,
466*da0073e9SAndroid Build Coastguard Worker                groups,
467*da0073e9SAndroid Build Coastguard Worker            ):
468*da0073e9SAndroid Build Coastguard Worker                super().__init__()
469*da0073e9SAndroid Build Coastguard Worker                self.weight = weight
470*da0073e9SAndroid Build Coastguard Worker                self.bias = bias
471*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
472*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
473*da0073e9SAndroid Build Coastguard Worker                self.output_paddings = output_paddings
474*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
475*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
478*da0073e9SAndroid Build Coastguard Worker                return F.conv_transpose2d(
479*da0073e9SAndroid Build Coastguard Worker                    x,
480*da0073e9SAndroid Build Coastguard Worker                    self.weight,
481*da0073e9SAndroid Build Coastguard Worker                    self.bias,
482*da0073e9SAndroid Build Coastguard Worker                    self.strides,
483*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
484*da0073e9SAndroid Build Coastguard Worker                    self.output_paddings,
485*da0073e9SAndroid Build Coastguard Worker                    self.groups,
486*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
487*da0073e9SAndroid Build Coastguard Worker                )
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Worker        class Conv2DTPrePacked(torch.nn.Module):
490*da0073e9SAndroid Build Coastguard Worker            def __init__(
491*da0073e9SAndroid Build Coastguard Worker                self,
492*da0073e9SAndroid Build Coastguard Worker                weight,
493*da0073e9SAndroid Build Coastguard Worker                bias,
494*da0073e9SAndroid Build Coastguard Worker                strides,
495*da0073e9SAndroid Build Coastguard Worker                paddings,
496*da0073e9SAndroid Build Coastguard Worker                output_paddings,
497*da0073e9SAndroid Build Coastguard Worker                dilations,
498*da0073e9SAndroid Build Coastguard Worker                groups,
499*da0073e9SAndroid Build Coastguard Worker            ):
500*da0073e9SAndroid Build Coastguard Worker                super().__init__()
501*da0073e9SAndroid Build Coastguard Worker                self.packed_weight_bias = (
502*da0073e9SAndroid Build Coastguard Worker                    torch.ops.prepacked.conv2d_transpose_clamp_prepack(
503*da0073e9SAndroid Build Coastguard Worker                        weight,
504*da0073e9SAndroid Build Coastguard Worker                        bias,
505*da0073e9SAndroid Build Coastguard Worker                        strides,
506*da0073e9SAndroid Build Coastguard Worker                        paddings,
507*da0073e9SAndroid Build Coastguard Worker                        output_paddings,
508*da0073e9SAndroid Build Coastguard Worker                        dilations,
509*da0073e9SAndroid Build Coastguard Worker                        groups,
510*da0073e9SAndroid Build Coastguard Worker                    )
511*da0073e9SAndroid Build Coastguard Worker                )
512*da0073e9SAndroid Build Coastguard Worker
513*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
514*da0073e9SAndroid Build Coastguard Worker                return torch.ops.prepacked.conv2d_transpose_clamp_run(
515*da0073e9SAndroid Build Coastguard Worker                    x, self.packed_weight_bias
516*da0073e9SAndroid Build Coastguard Worker                )
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
519*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
520*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
521*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
522*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
523*da0073e9SAndroid Build Coastguard Worker        output_paddings = (output_pad_h, output_pad_w)
524*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
525*da0073e9SAndroid Build Coastguard Worker        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
526*da0073e9SAndroid Build Coastguard Worker        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
527*da0073e9SAndroid Build Coastguard Worker        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
528*da0073e9SAndroid Build Coastguard Worker        assume((output_pad_w < stride_w) and (output_pad_w < dilation))
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
531*da0073e9SAndroid Build Coastguard Worker        if format is not None:
532*da0073e9SAndroid Build Coastguard Worker            input_data = input_data.contiguous(memory_format=format)
533*da0073e9SAndroid Build Coastguard Worker        weight = torch.rand(
534*da0073e9SAndroid Build Coastguard Worker            (input_channels, output_channels_per_group, kernel_h, kernel_w)
535*da0073e9SAndroid Build Coastguard Worker        )
536*da0073e9SAndroid Build Coastguard Worker        bias = None
537*da0073e9SAndroid Build Coastguard Worker        if use_bias:
538*da0073e9SAndroid Build Coastguard Worker            bias = torch.rand(output_channels)
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker        scripted_conv2d = torch.jit.script(
541*da0073e9SAndroid Build Coastguard Worker            Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups)
542*da0073e9SAndroid Build Coastguard Worker        )
543*da0073e9SAndroid Build Coastguard Worker        scripted_conv2d_clamp_prepacked = torch.jit.script(
544*da0073e9SAndroid Build Coastguard Worker            Conv2DTPrePacked(
545*da0073e9SAndroid Build Coastguard Worker                weight, bias, strides, paddings, output_paddings, dilations, groups
546*da0073e9SAndroid Build Coastguard Worker            )
547*da0073e9SAndroid Build Coastguard Worker        )
548*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_conv2d(input_data)
549*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
550*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        # Serialize the modules and then deserialize
553*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
554*da0073e9SAndroid Build Coastguard Worker        if format is not None:
555*da0073e9SAndroid Build Coastguard Worker            input_data = input_data.contiguous(memory_format=format)
556*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
557*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_conv2d, buffer)
558*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
559*da0073e9SAndroid Build Coastguard Worker        deserialized_conv2d = torch.jit.load(buffer)
560*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
561*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
562*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
563*da0073e9SAndroid Build Coastguard Worker        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
564*da0073e9SAndroid Build Coastguard Worker        ref_result = deserialized_conv2d(input_data)
565*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
566*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker    @unittest.skip(
569*da0073e9SAndroid Build Coastguard Worker        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
570*da0073e9SAndroid Build Coastguard Worker    )
571*da0073e9SAndroid Build Coastguard Worker    @given(
572*da0073e9SAndroid Build Coastguard Worker        batch_size=st.integers(0, 3),
573*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group=st.integers(1, 32),
574*da0073e9SAndroid Build Coastguard Worker        height=st.integers(5, 64),
575*da0073e9SAndroid Build Coastguard Worker        width=st.integers(5, 64),
576*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group=st.integers(1, 32),
577*da0073e9SAndroid Build Coastguard Worker        groups=st.integers(1, 16),
578*da0073e9SAndroid Build Coastguard Worker        kernel_h=st.integers(1, 7),
579*da0073e9SAndroid Build Coastguard Worker        kernel_w=st.integers(1, 7),
580*da0073e9SAndroid Build Coastguard Worker        stride_h=st.integers(1, 2),
581*da0073e9SAndroid Build Coastguard Worker        stride_w=st.integers(1, 2),
582*da0073e9SAndroid Build Coastguard Worker        pad_h=st.integers(0, 2),
583*da0073e9SAndroid Build Coastguard Worker        pad_w=st.integers(0, 2),
584*da0073e9SAndroid Build Coastguard Worker        dilation=st.integers(1, 2),
585*da0073e9SAndroid Build Coastguard Worker        linear_weight_output_dim=st.integers(2, 64),
586*da0073e9SAndroid Build Coastguard Worker        use_bias=st.booleans(),
587*da0073e9SAndroid Build Coastguard Worker        format=st.sampled_from(
588*da0073e9SAndroid Build Coastguard Worker            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
589*da0073e9SAndroid Build Coastguard Worker        ),
590*da0073e9SAndroid Build Coastguard Worker    )
591*da0073e9SAndroid Build Coastguard Worker    def test_combined_model(
592*da0073e9SAndroid Build Coastguard Worker        self,
593*da0073e9SAndroid Build Coastguard Worker        batch_size,
594*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group,
595*da0073e9SAndroid Build Coastguard Worker        height,
596*da0073e9SAndroid Build Coastguard Worker        width,
597*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group,
598*da0073e9SAndroid Build Coastguard Worker        groups,
599*da0073e9SAndroid Build Coastguard Worker        kernel_h,
600*da0073e9SAndroid Build Coastguard Worker        kernel_w,
601*da0073e9SAndroid Build Coastguard Worker        stride_h,
602*da0073e9SAndroid Build Coastguard Worker        stride_w,
603*da0073e9SAndroid Build Coastguard Worker        pad_h,
604*da0073e9SAndroid Build Coastguard Worker        pad_w,
605*da0073e9SAndroid Build Coastguard Worker        dilation,
606*da0073e9SAndroid Build Coastguard Worker        linear_weight_output_dim,
607*da0073e9SAndroid Build Coastguard Worker        use_bias,
608*da0073e9SAndroid Build Coastguard Worker        format,
609*da0073e9SAndroid Build Coastguard Worker    ):
610*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
611*da0073e9SAndroid Build Coastguard Worker            def __init__(
612*da0073e9SAndroid Build Coastguard Worker                self,
613*da0073e9SAndroid Build Coastguard Worker                conv_weight,
614*da0073e9SAndroid Build Coastguard Worker                conv_bias,
615*da0073e9SAndroid Build Coastguard Worker                linear_weight,
616*da0073e9SAndroid Build Coastguard Worker                linear_bias,
617*da0073e9SAndroid Build Coastguard Worker                strides,
618*da0073e9SAndroid Build Coastguard Worker                paddings,
619*da0073e9SAndroid Build Coastguard Worker                dilations,
620*da0073e9SAndroid Build Coastguard Worker                groups,
621*da0073e9SAndroid Build Coastguard Worker            ):
622*da0073e9SAndroid Build Coastguard Worker                super().__init__()
623*da0073e9SAndroid Build Coastguard Worker                self.conv_weight = conv_weight
624*da0073e9SAndroid Build Coastguard Worker                self.conv_bias = conv_bias
625*da0073e9SAndroid Build Coastguard Worker                self.linear_weight = linear_weight
626*da0073e9SAndroid Build Coastguard Worker                self.linear_bias = linear_bias
627*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
628*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
629*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
630*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
631*da0073e9SAndroid Build Coastguard Worker
632*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
633*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(
634*da0073e9SAndroid Build Coastguard Worker                    x,
635*da0073e9SAndroid Build Coastguard Worker                    self.conv_weight,
636*da0073e9SAndroid Build Coastguard Worker                    self.conv_bias,
637*da0073e9SAndroid Build Coastguard Worker                    self.strides,
638*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
639*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
640*da0073e9SAndroid Build Coastguard Worker                    self.groups,
641*da0073e9SAndroid Build Coastguard Worker                )
642*da0073e9SAndroid Build Coastguard Worker                o = o.permute([0, 2, 3, 1])
643*da0073e9SAndroid Build Coastguard Worker                o = F.linear(o, self.linear_weight, self.linear_bias)
644*da0073e9SAndroid Build Coastguard Worker                return F.relu(o)
645*da0073e9SAndroid Build Coastguard Worker
646*da0073e9SAndroid Build Coastguard Worker        class MPrePacked(torch.nn.Module):
647*da0073e9SAndroid Build Coastguard Worker            def __init__(
648*da0073e9SAndroid Build Coastguard Worker                self,
649*da0073e9SAndroid Build Coastguard Worker                conv_weight,
650*da0073e9SAndroid Build Coastguard Worker                conv_bias,
651*da0073e9SAndroid Build Coastguard Worker                linear_weight,
652*da0073e9SAndroid Build Coastguard Worker                linear_bias,
653*da0073e9SAndroid Build Coastguard Worker                strides,
654*da0073e9SAndroid Build Coastguard Worker                paddings,
655*da0073e9SAndroid Build Coastguard Worker                dilations,
656*da0073e9SAndroid Build Coastguard Worker                groups,
657*da0073e9SAndroid Build Coastguard Worker            ):
658*da0073e9SAndroid Build Coastguard Worker                super().__init__()
659*da0073e9SAndroid Build Coastguard Worker                self.conv2d_clamp_run_weight_bias = (
660*da0073e9SAndroid Build Coastguard Worker                    torch.ops.prepacked.conv2d_clamp_prepack(
661*da0073e9SAndroid Build Coastguard Worker                        conv_weight, conv_bias, strides, paddings, dilations, groups
662*da0073e9SAndroid Build Coastguard Worker                    )
663*da0073e9SAndroid Build Coastguard Worker                )
664*da0073e9SAndroid Build Coastguard Worker                self.linear_clamp_run_weight_bias = (
665*da0073e9SAndroid Build Coastguard Worker                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
666*da0073e9SAndroid Build Coastguard Worker                )
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
669*da0073e9SAndroid Build Coastguard Worker                o = torch.ops.prepacked.conv2d_clamp_run(
670*da0073e9SAndroid Build Coastguard Worker                    x, self.conv2d_clamp_run_weight_bias
671*da0073e9SAndroid Build Coastguard Worker                )
672*da0073e9SAndroid Build Coastguard Worker                o = o.permute([0, 2, 3, 1])
673*da0073e9SAndroid Build Coastguard Worker                o = torch.ops.prepacked.linear_clamp_run(
674*da0073e9SAndroid Build Coastguard Worker                    o, self.linear_clamp_run_weight_bias
675*da0073e9SAndroid Build Coastguard Worker                )
676*da0073e9SAndroid Build Coastguard Worker                return F.relu(o)
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
679*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
680*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
681*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
682*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
683*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
684*da0073e9SAndroid Build Coastguard Worker        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
685*da0073e9SAndroid Build Coastguard Worker        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
688*da0073e9SAndroid Build Coastguard Worker        if format is not None:
689*da0073e9SAndroid Build Coastguard Worker            input_data = input_data.contiguous(memory_format=format)
690*da0073e9SAndroid Build Coastguard Worker        conv_weight = torch.rand(
691*da0073e9SAndroid Build Coastguard Worker            (output_channels, input_channels_per_group, kernel_h, kernel_w)
692*da0073e9SAndroid Build Coastguard Worker        )
693*da0073e9SAndroid Build Coastguard Worker        conv_bias = None
694*da0073e9SAndroid Build Coastguard Worker        if use_bias:
695*da0073e9SAndroid Build Coastguard Worker            conv_bias = torch.rand(output_channels)
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        # This is done just to find the output shape of the result
698*da0073e9SAndroid Build Coastguard Worker        # so that the shape of weight for the following linear layer
699*da0073e9SAndroid Build Coastguard Worker        # can be determined.
700*da0073e9SAndroid Build Coastguard Worker        result = F.conv2d(
701*da0073e9SAndroid Build Coastguard Worker            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
702*da0073e9SAndroid Build Coastguard Worker        )
703*da0073e9SAndroid Build Coastguard Worker        linear_input_shape = result.shape[1]
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
706*da0073e9SAndroid Build Coastguard Worker        linear_bias = None
707*da0073e9SAndroid Build Coastguard Worker        if use_bias:
708*da0073e9SAndroid Build Coastguard Worker            linear_bias = torch.rand(linear_weight_output_dim)
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker        scripted_m = torch.jit.script(
711*da0073e9SAndroid Build Coastguard Worker            M(
712*da0073e9SAndroid Build Coastguard Worker                conv_weight,
713*da0073e9SAndroid Build Coastguard Worker                conv_bias,
714*da0073e9SAndroid Build Coastguard Worker                linear_weight,
715*da0073e9SAndroid Build Coastguard Worker                linear_bias,
716*da0073e9SAndroid Build Coastguard Worker                strides,
717*da0073e9SAndroid Build Coastguard Worker                paddings,
718*da0073e9SAndroid Build Coastguard Worker                dilations,
719*da0073e9SAndroid Build Coastguard Worker                groups,
720*da0073e9SAndroid Build Coastguard Worker            )
721*da0073e9SAndroid Build Coastguard Worker        )
722*da0073e9SAndroid Build Coastguard Worker        scripted_m_prepacked = torch.jit.script(
723*da0073e9SAndroid Build Coastguard Worker            MPrePacked(
724*da0073e9SAndroid Build Coastguard Worker                conv_weight,
725*da0073e9SAndroid Build Coastguard Worker                conv_bias,
726*da0073e9SAndroid Build Coastguard Worker                linear_weight,
727*da0073e9SAndroid Build Coastguard Worker                linear_bias,
728*da0073e9SAndroid Build Coastguard Worker                strides,
729*da0073e9SAndroid Build Coastguard Worker                paddings,
730*da0073e9SAndroid Build Coastguard Worker                dilations,
731*da0073e9SAndroid Build Coastguard Worker                groups,
732*da0073e9SAndroid Build Coastguard Worker            )
733*da0073e9SAndroid Build Coastguard Worker        )
734*da0073e9SAndroid Build Coastguard Worker        ref_result = scripted_m(input_data)
735*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = scripted_m_prepacked(input_data)
736*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker        # Serialize the modules and then deserialize
739*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
740*da0073e9SAndroid Build Coastguard Worker        input_data = input_data.contiguous(memory_format=torch.channels_last)
741*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
742*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_m, buffer)
743*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
744*da0073e9SAndroid Build Coastguard Worker        deserialized_m = torch.jit.load(buffer)
745*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
746*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(scripted_m_prepacked, buffer)
747*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
748*da0073e9SAndroid Build Coastguard Worker        deserialized_m_prepacked = torch.jit.load(buffer)
749*da0073e9SAndroid Build Coastguard Worker        ref_result = deserialized_m(input_data)
750*da0073e9SAndroid Build Coastguard Worker        xnnpack_result = deserialized_m_prepacked(input_data)
751*da0073e9SAndroid Build Coastguard Worker        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker
754*da0073e9SAndroid Build Coastguard Worker@unittest.skipUnless(
755*da0073e9SAndroid Build Coastguard Worker    torch.backends.xnnpack.enabled,
756*da0073e9SAndroid Build Coastguard Worker    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
757*da0073e9SAndroid Build Coastguard Worker)
758*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(
759*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TSAN,
760*da0073e9SAndroid Build Coastguard Worker    "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.",
761*da0073e9SAndroid Build Coastguard Worker)
762*da0073e9SAndroid Build Coastguard Workerclass TestXNNPACKRewritePass(TestCase):
763*da0073e9SAndroid Build Coastguard Worker    @staticmethod
764*da0073e9SAndroid Build Coastguard Worker    def validate_transformed_module(
765*da0073e9SAndroid Build Coastguard Worker        # To please flake
766*da0073e9SAndroid Build Coastguard Worker        self,
767*da0073e9SAndroid Build Coastguard Worker        pattern_count_map,
768*da0073e9SAndroid Build Coastguard Worker        data_shape,
769*da0073e9SAndroid Build Coastguard Worker        prepack_removal=False,
770*da0073e9SAndroid Build Coastguard Worker        fuse_clamping_ops=False,
771*da0073e9SAndroid Build Coastguard Worker    ):
772*da0073e9SAndroid Build Coastguard Worker        input_data = torch.normal(1, 20, size=data_shape)
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker        for jit_method in ["script", "trace"]:
775*da0073e9SAndroid Build Coastguard Worker            module_instance = self
776*da0073e9SAndroid Build Coastguard Worker            if jit_method == "script":
777*da0073e9SAndroid Build Coastguard Worker                scripted_model = torch.jit.script(module_instance)
778*da0073e9SAndroid Build Coastguard Worker            else:
779*da0073e9SAndroid Build Coastguard Worker                scripted_model = torch.jit.trace(module_instance, input_data)
780*da0073e9SAndroid Build Coastguard Worker            scripted_model.eval()
781*da0073e9SAndroid Build Coastguard Worker            ref_result = scripted_model(input_data)
782*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
783*da0073e9SAndroid Build Coastguard Worker            if fuse_clamping_ops or prepack_removal:
784*da0073e9SAndroid Build Coastguard Worker                scripted_model._c = torch._C._freeze_module(scripted_model._c)
785*da0073e9SAndroid Build Coastguard Worker            if fuse_clamping_ops:
786*da0073e9SAndroid Build Coastguard Worker                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
787*da0073e9SAndroid Build Coastguard Worker            if prepack_removal:
788*da0073e9SAndroid Build Coastguard Worker                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
789*da0073e9SAndroid Build Coastguard Worker
790*da0073e9SAndroid Build Coastguard Worker            buffer = io.BytesIO()
791*da0073e9SAndroid Build Coastguard Worker            torch.jit.save(scripted_model, buffer)
792*da0073e9SAndroid Build Coastguard Worker            buffer.seek(0)
793*da0073e9SAndroid Build Coastguard Worker            deserialized_scripted_model = torch.jit.load(buffer)
794*da0073e9SAndroid Build Coastguard Worker            for pattern, v in pattern_count_map.items():
795*da0073e9SAndroid Build Coastguard Worker                if v == 0:
796*da0073e9SAndroid Build Coastguard Worker                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
797*da0073e9SAndroid Build Coastguard Worker                elif v == -1:
798*da0073e9SAndroid Build Coastguard Worker                    FileCheck().check_not(pattern).run(
799*da0073e9SAndroid Build Coastguard Worker                        deserialized_scripted_model.graph
800*da0073e9SAndroid Build Coastguard Worker                    )
801*da0073e9SAndroid Build Coastguard Worker                else:
802*da0073e9SAndroid Build Coastguard Worker                    FileCheck().check_count(pattern, v, exactly=True).run(
803*da0073e9SAndroid Build Coastguard Worker                        deserialized_scripted_model.graph
804*da0073e9SAndroid Build Coastguard Worker                    )
805*da0073e9SAndroid Build Coastguard Worker            xnnpack_result = deserialized_scripted_model(input_data)
806*da0073e9SAndroid Build Coastguard Worker            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker    def test_linear(self):
809*da0073e9SAndroid Build Coastguard Worker        data_shape = [2, 3, 32]
810*da0073e9SAndroid Build Coastguard Worker        weight_output_dim = 24
811*da0073e9SAndroid Build Coastguard Worker        weight_shape = (weight_output_dim, data_shape[-1])
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker        class Linear(torch.nn.Module):
814*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
815*da0073e9SAndroid Build Coastguard Worker                super().__init__()
816*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
817*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_shape), requires_grad=False
818*da0073e9SAndroid Build Coastguard Worker                )
819*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
820*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
821*da0073e9SAndroid Build Coastguard Worker                )
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
824*da0073e9SAndroid Build Coastguard Worker                return F.linear(x, self.weight, self.bias)
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker        class LinearNoBias(torch.nn.Module):
827*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
828*da0073e9SAndroid Build Coastguard Worker                super().__init__()
829*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
830*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_shape), requires_grad=False
831*da0073e9SAndroid Build Coastguard Worker                )
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
834*da0073e9SAndroid Build Coastguard Worker                return F.linear(x, self.weight, None)
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker        # Linear with bias pattern.
837*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
838*da0073e9SAndroid Build Coastguard Worker            "Tensor = prim::CallFunction": -1,
839*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": 1,
840*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
841*da0073e9SAndroid Build Coastguard Worker        }
842*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
843*da0073e9SAndroid Build Coastguard Worker            Linear(), pattern_count_map, data_shape
844*da0073e9SAndroid Build Coastguard Worker        )
845*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
846*da0073e9SAndroid Build Coastguard Worker            LinearNoBias(), pattern_count_map, data_shape
847*da0073e9SAndroid Build Coastguard Worker        )
848*da0073e9SAndroid Build Coastguard Worker
849*da0073e9SAndroid Build Coastguard Worker        # Conv params
850*da0073e9SAndroid Build Coastguard Worker        batch_size = 2
851*da0073e9SAndroid Build Coastguard Worker        input_channels_per_group = 6
852*da0073e9SAndroid Build Coastguard Worker        height = 16
853*da0073e9SAndroid Build Coastguard Worker        width = 16
854*da0073e9SAndroid Build Coastguard Worker        output_channels_per_group = 6
855*da0073e9SAndroid Build Coastguard Worker        groups = 4
856*da0073e9SAndroid Build Coastguard Worker        kernel_h = kernel_w = 3
857*da0073e9SAndroid Build Coastguard Worker        stride_h = stride_w = 1
858*da0073e9SAndroid Build Coastguard Worker        pad_h = pad_w = 1
859*da0073e9SAndroid Build Coastguard Worker        output_pad_h = output_pad_w = 0
860*da0073e9SAndroid Build Coastguard Worker        dilation = 1
861*da0073e9SAndroid Build Coastguard Worker        input_channels = input_channels_per_group * groups
862*da0073e9SAndroid Build Coastguard Worker        output_channels = output_channels_per_group * groups
863*da0073e9SAndroid Build Coastguard Worker        kernels = (kernel_h, kernel_w)
864*da0073e9SAndroid Build Coastguard Worker        strides = (stride_h, stride_w)
865*da0073e9SAndroid Build Coastguard Worker        paddings = (pad_h, pad_w)
866*da0073e9SAndroid Build Coastguard Worker        output_paddings = (output_pad_h, output_pad_w)
867*da0073e9SAndroid Build Coastguard Worker        dilations = (dilation, dilation)
868*da0073e9SAndroid Build Coastguard Worker        conv_weight_shape = (
869*da0073e9SAndroid Build Coastguard Worker            output_channels,
870*da0073e9SAndroid Build Coastguard Worker            input_channels_per_group,
871*da0073e9SAndroid Build Coastguard Worker            kernel_h,
872*da0073e9SAndroid Build Coastguard Worker            kernel_w,
873*da0073e9SAndroid Build Coastguard Worker        )
874*da0073e9SAndroid Build Coastguard Worker        conv_transpose_weight_shape = (
875*da0073e9SAndroid Build Coastguard Worker            input_channels,
876*da0073e9SAndroid Build Coastguard Worker            output_channels_per_group,
877*da0073e9SAndroid Build Coastguard Worker            kernel_h,
878*da0073e9SAndroid Build Coastguard Worker            kernel_w,
879*da0073e9SAndroid Build Coastguard Worker        )
880*da0073e9SAndroid Build Coastguard Worker        conv_bias_shape = output_channels
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker        class Conv2D(torch.nn.Module):
883*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
884*da0073e9SAndroid Build Coastguard Worker                super().__init__()
885*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
886*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_weight_shape), requires_grad=False
887*da0073e9SAndroid Build Coastguard Worker                )
888*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
889*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_bias_shape), requires_grad=False
890*da0073e9SAndroid Build Coastguard Worker                )
891*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
892*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
893*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
894*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
897*da0073e9SAndroid Build Coastguard Worker                return F.conv2d(
898*da0073e9SAndroid Build Coastguard Worker                    x,
899*da0073e9SAndroid Build Coastguard Worker                    self.weight,
900*da0073e9SAndroid Build Coastguard Worker                    self.bias,
901*da0073e9SAndroid Build Coastguard Worker                    self.strides,
902*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
903*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
904*da0073e9SAndroid Build Coastguard Worker                    self.groups,
905*da0073e9SAndroid Build Coastguard Worker                )
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker        class Conv2DT(torch.nn.Module):
908*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
909*da0073e9SAndroid Build Coastguard Worker                super().__init__()
910*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
911*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_transpose_weight_shape), requires_grad=False
912*da0073e9SAndroid Build Coastguard Worker                )
913*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
914*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_bias_shape), requires_grad=False
915*da0073e9SAndroid Build Coastguard Worker                )
916*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
917*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
918*da0073e9SAndroid Build Coastguard Worker                self.output_paddings = output_paddings
919*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
920*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
923*da0073e9SAndroid Build Coastguard Worker                return F.conv_transpose2d(
924*da0073e9SAndroid Build Coastguard Worker                    x,
925*da0073e9SAndroid Build Coastguard Worker                    self.weight,
926*da0073e9SAndroid Build Coastguard Worker                    self.bias,
927*da0073e9SAndroid Build Coastguard Worker                    self.strides,
928*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
929*da0073e9SAndroid Build Coastguard Worker                    self.output_paddings,
930*da0073e9SAndroid Build Coastguard Worker                    self.groups,
931*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
932*da0073e9SAndroid Build Coastguard Worker                )
933*da0073e9SAndroid Build Coastguard Worker
934*da0073e9SAndroid Build Coastguard Worker        data_shape = (batch_size, input_channels, height, width)
935*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
936*da0073e9SAndroid Build Coastguard Worker            "Tensor = aten::conv2d": -1,
937*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": 1,
938*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
939*da0073e9SAndroid Build Coastguard Worker        }
940*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
941*da0073e9SAndroid Build Coastguard Worker            Conv2D(), pattern_count_map, data_shape
942*da0073e9SAndroid Build Coastguard Worker        )
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        transpose_data_shape = (batch_size, input_channels, height, width)
945*da0073e9SAndroid Build Coastguard Worker        transpose_pattern_count_map = {
946*da0073e9SAndroid Build Coastguard Worker            "Tensor = aten::conv_transpose2d": -1,
947*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_transpose_clamp_prepack": 1,
948*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_transpose_clamp_run": 1,
949*da0073e9SAndroid Build Coastguard Worker        }
950*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
951*da0073e9SAndroid Build Coastguard Worker            Conv2DT(), transpose_pattern_count_map, data_shape
952*da0073e9SAndroid Build Coastguard Worker        )
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker        input_data = torch.rand((batch_size, input_channels, height, width))
955*da0073e9SAndroid Build Coastguard Worker        conv_weight = torch.rand(
956*da0073e9SAndroid Build Coastguard Worker            (output_channels, input_channels_per_group, kernel_h, kernel_w)
957*da0073e9SAndroid Build Coastguard Worker        )
958*da0073e9SAndroid Build Coastguard Worker        conv_bias = torch.rand(output_channels)
959*da0073e9SAndroid Build Coastguard Worker        result = F.conv2d(
960*da0073e9SAndroid Build Coastguard Worker            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
961*da0073e9SAndroid Build Coastguard Worker        )
962*da0073e9SAndroid Build Coastguard Worker        linear_input_shape = result.shape[1]
963*da0073e9SAndroid Build Coastguard Worker        linear_weight_shape = (weight_output_dim, linear_input_shape)
964*da0073e9SAndroid Build Coastguard Worker
965*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
966*da0073e9SAndroid Build Coastguard Worker            def __init__(self, activation_fn=F.relu):
967*da0073e9SAndroid Build Coastguard Worker                super().__init__()
968*da0073e9SAndroid Build Coastguard Worker                self.conv_weight = torch.nn.Parameter(
969*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_weight_shape), requires_grad=False
970*da0073e9SAndroid Build Coastguard Worker                )
971*da0073e9SAndroid Build Coastguard Worker                self.conv_bias = torch.nn.Parameter(
972*da0073e9SAndroid Build Coastguard Worker                    torch.rand(conv_bias_shape), requires_grad=False
973*da0073e9SAndroid Build Coastguard Worker                )
974*da0073e9SAndroid Build Coastguard Worker                self.linear_weight = torch.nn.Parameter(
975*da0073e9SAndroid Build Coastguard Worker                    torch.rand(linear_weight_shape), requires_grad=False
976*da0073e9SAndroid Build Coastguard Worker                )
977*da0073e9SAndroid Build Coastguard Worker                self.linear_bias = torch.nn.Parameter(
978*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
979*da0073e9SAndroid Build Coastguard Worker                )
980*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
981*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
982*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
983*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
984*da0073e9SAndroid Build Coastguard Worker                self.activation_fn = activation_fn
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
987*da0073e9SAndroid Build Coastguard Worker                o = F.conv2d(
988*da0073e9SAndroid Build Coastguard Worker                    x,
989*da0073e9SAndroid Build Coastguard Worker                    self.conv_weight,
990*da0073e9SAndroid Build Coastguard Worker                    self.conv_bias,
991*da0073e9SAndroid Build Coastguard Worker                    self.strides,
992*da0073e9SAndroid Build Coastguard Worker                    self.paddings,
993*da0073e9SAndroid Build Coastguard Worker                    self.dilations,
994*da0073e9SAndroid Build Coastguard Worker                    self.groups,
995*da0073e9SAndroid Build Coastguard Worker                )
996*da0073e9SAndroid Build Coastguard Worker                o = self.activation_fn(o)
997*da0073e9SAndroid Build Coastguard Worker                o = o.permute([0, 2, 3, 1])
998*da0073e9SAndroid Build Coastguard Worker                o = F.linear(o, self.linear_weight, self.linear_bias)
999*da0073e9SAndroid Build Coastguard Worker                return self.activation_fn(o)
1000*da0073e9SAndroid Build Coastguard Worker
1001*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1002*da0073e9SAndroid Build Coastguard Worker            "Tensor = aten::conv2d": -1,
1003*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": 1,
1004*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
1005*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": 1,
1006*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1007*da0073e9SAndroid Build Coastguard Worker        }
1008*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1009*da0073e9SAndroid Build Coastguard Worker            M(), pattern_count_map, data_shape
1010*da0073e9SAndroid Build Coastguard Worker        )
1011*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1012*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["Tensor = prim::CallFunction"] = -1
1013*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1014*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1015*da0073e9SAndroid Build Coastguard Worker            M(), pattern_count_map, data_shape, prepack_removal=True
1016*da0073e9SAndroid Build Coastguard Worker        )
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # Not inplace relu fusion test.
1019*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1020*da0073e9SAndroid Build Coastguard Worker            "aten::relu": 2,
1021*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": -1,
1022*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
1023*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1024*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1025*da0073e9SAndroid Build Coastguard Worker        }
1026*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1027*da0073e9SAndroid Build Coastguard Worker            M(), pattern_count_map, data_shape, prepack_removal=True
1028*da0073e9SAndroid Build Coastguard Worker        )
1029*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1030*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1031*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = -1
1032*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1033*da0073e9SAndroid Build Coastguard Worker            M(),
1034*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1035*da0073e9SAndroid Build Coastguard Worker            data_shape,
1036*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1037*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1038*da0073e9SAndroid Build Coastguard Worker        )
1039*da0073e9SAndroid Build Coastguard Worker
1040*da0073e9SAndroid Build Coastguard Worker        # Inplace relu fusion test.
1041*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1042*da0073e9SAndroid Build Coastguard Worker            "aten::relu": 2,
1043*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": -1,
1044*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
1045*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1046*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1047*da0073e9SAndroid Build Coastguard Worker        }
1048*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1049*da0073e9SAndroid Build Coastguard Worker            M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
1050*da0073e9SAndroid Build Coastguard Worker        )
1051*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1052*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1053*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::relu"] = -1
1054*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1055*da0073e9SAndroid Build Coastguard Worker            M(F.relu_),
1056*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1057*da0073e9SAndroid Build Coastguard Worker            data_shape,
1058*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1059*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1060*da0073e9SAndroid Build Coastguard Worker        )
1061*da0073e9SAndroid Build Coastguard Worker
1062*da0073e9SAndroid Build Coastguard Worker        # Not inplace hardtanh fusion test.
1063*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1064*da0073e9SAndroid Build Coastguard Worker            "aten::hardtanh": 2,
1065*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": -1,
1066*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
1067*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1068*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1069*da0073e9SAndroid Build Coastguard Worker        }
1070*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1071*da0073e9SAndroid Build Coastguard Worker            M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
1072*da0073e9SAndroid Build Coastguard Worker        )
1073*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1074*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1075*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh"] = -1
1076*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1077*da0073e9SAndroid Build Coastguard Worker            M(F.hardtanh),
1078*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1079*da0073e9SAndroid Build Coastguard Worker            data_shape,
1080*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1081*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1082*da0073e9SAndroid Build Coastguard Worker        )
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker        # Inplace hardtanh fusion test.
1085*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1086*da0073e9SAndroid Build Coastguard Worker            "aten::hardtanh_": 2,
1087*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_prepack": -1,
1088*da0073e9SAndroid Build Coastguard Worker            "prepacked::conv2d_clamp_run": 1,
1089*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1090*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1091*da0073e9SAndroid Build Coastguard Worker        }
1092*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1093*da0073e9SAndroid Build Coastguard Worker            M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
1094*da0073e9SAndroid Build Coastguard Worker        )
1095*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
1096*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
1097*da0073e9SAndroid Build Coastguard Worker        pattern_count_map["aten::hardtanh_"] = -1
1098*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1099*da0073e9SAndroid Build Coastguard Worker            M(F.hardtanh_),
1100*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1101*da0073e9SAndroid Build Coastguard Worker            data_shape,
1102*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1103*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1104*da0073e9SAndroid Build Coastguard Worker        )
1105*da0073e9SAndroid Build Coastguard Worker
1106*da0073e9SAndroid Build Coastguard Worker        class MFusionAntiPattern(torch.nn.Module):
1107*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1108*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1109*da0073e9SAndroid Build Coastguard Worker                self.linear_weight = torch.nn.Parameter(
1110*da0073e9SAndroid Build Coastguard Worker                    torch.rand(linear_weight_shape), requires_grad=False
1111*da0073e9SAndroid Build Coastguard Worker                )
1112*da0073e9SAndroid Build Coastguard Worker                self.linear_bias = torch.nn.Parameter(
1113*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
1114*da0073e9SAndroid Build Coastguard Worker                )
1115*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
1116*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
1117*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
1118*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1121*da0073e9SAndroid Build Coastguard Worker                o = F.linear(x, self.linear_weight, self.linear_bias)
1122*da0073e9SAndroid Build Coastguard Worker                o = F.relu(o)
1123*da0073e9SAndroid Build Coastguard Worker                o = F.hardtanh(o)
1124*da0073e9SAndroid Build Coastguard Worker                return o
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker        # Unfusable hardtanh.
1127*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1128*da0073e9SAndroid Build Coastguard Worker            "aten::hardtanh": 1,  # hardtanh cannot be.
1129*da0073e9SAndroid Build Coastguard Worker            "aten::relu": -1,  # relu is fused.
1130*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1131*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1132*da0073e9SAndroid Build Coastguard Worker        }
1133*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1134*da0073e9SAndroid Build Coastguard Worker            MFusionAntiPattern(),
1135*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1136*da0073e9SAndroid Build Coastguard Worker            (16, linear_weight_shape[1]),
1137*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1138*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1139*da0073e9SAndroid Build Coastguard Worker        )
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        class MFusionAntiPatternParamMinMax(torch.nn.Module):
1142*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1143*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1144*da0073e9SAndroid Build Coastguard Worker                self.linear_weight = torch.nn.Parameter(
1145*da0073e9SAndroid Build Coastguard Worker                    torch.rand(linear_weight_shape), requires_grad=False
1146*da0073e9SAndroid Build Coastguard Worker                )
1147*da0073e9SAndroid Build Coastguard Worker                self.linear_bias = torch.nn.Parameter(
1148*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
1149*da0073e9SAndroid Build Coastguard Worker                )
1150*da0073e9SAndroid Build Coastguard Worker                self.strides = strides
1151*da0073e9SAndroid Build Coastguard Worker                self.paddings = paddings
1152*da0073e9SAndroid Build Coastguard Worker                self.dilations = dilations
1153*da0073e9SAndroid Build Coastguard Worker                self.groups = groups
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1156*da0073e9SAndroid Build Coastguard Worker                min = x[0, 0]
1157*da0073e9SAndroid Build Coastguard Worker                max = min + 10
1158*da0073e9SAndroid Build Coastguard Worker                o = F.linear(x, self.linear_weight, self.linear_bias)
1159*da0073e9SAndroid Build Coastguard Worker                o = F.hardtanh(o, min, max)
1160*da0073e9SAndroid Build Coastguard Worker                return o
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker        # Unfusable hardtanh.
1163*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1164*da0073e9SAndroid Build Coastguard Worker            "aten::hardtanh": 1,  # hardtanh cannot be.
1165*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": -1,
1166*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1167*da0073e9SAndroid Build Coastguard Worker        }
1168*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1169*da0073e9SAndroid Build Coastguard Worker            MFusionAntiPatternParamMinMax(),
1170*da0073e9SAndroid Build Coastguard Worker            pattern_count_map,
1171*da0073e9SAndroid Build Coastguard Worker            (16, linear_weight_shape[1]),
1172*da0073e9SAndroid Build Coastguard Worker            prepack_removal=True,
1173*da0073e9SAndroid Build Coastguard Worker            fuse_clamping_ops=True,
1174*da0073e9SAndroid Build Coastguard Worker        )
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker    def test_decomposed_linear(self):
1177*da0073e9SAndroid Build Coastguard Worker        data_shape = [2, 32]
1178*da0073e9SAndroid Build Coastguard Worker        weight_output_dim = 24
1179*da0073e9SAndroid Build Coastguard Worker        weight_shape = (weight_output_dim, data_shape[-1])
1180*da0073e9SAndroid Build Coastguard Worker
1181*da0073e9SAndroid Build Coastguard Worker        class DecomposedLinearAddmm(torch.nn.Module):
1182*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1183*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1184*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
1185*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_shape), requires_grad=False
1186*da0073e9SAndroid Build Coastguard Worker                )
1187*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
1188*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
1189*da0073e9SAndroid Build Coastguard Worker                )
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1192*da0073e9SAndroid Build Coastguard Worker                weight_t = self.weight.t()
1193*da0073e9SAndroid Build Coastguard Worker                return torch.addmm(self.bias, x, weight_t)
1194*da0073e9SAndroid Build Coastguard Worker
1195*da0073e9SAndroid Build Coastguard Worker        class DecomposedLinearMatmulAdd(torch.nn.Module):
1196*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1197*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1198*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
1199*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_shape), requires_grad=False
1200*da0073e9SAndroid Build Coastguard Worker                )
1201*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
1202*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
1203*da0073e9SAndroid Build Coastguard Worker                )
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1206*da0073e9SAndroid Build Coastguard Worker                weight_t = self.weight.t()
1207*da0073e9SAndroid Build Coastguard Worker                y = torch.matmul(x, weight_t)
1208*da0073e9SAndroid Build Coastguard Worker                res = y.add_(self.bias)
1209*da0073e9SAndroid Build Coastguard Worker                return res
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker        class DecomposedLinearMatmul(torch.nn.Module):
1212*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1213*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1214*da0073e9SAndroid Build Coastguard Worker                self.weight = torch.nn.Parameter(
1215*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_shape), requires_grad=False
1216*da0073e9SAndroid Build Coastguard Worker                )
1217*da0073e9SAndroid Build Coastguard Worker                self.bias = torch.nn.Parameter(
1218*da0073e9SAndroid Build Coastguard Worker                    torch.rand(weight_output_dim), requires_grad=False
1219*da0073e9SAndroid Build Coastguard Worker                )
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
1222*da0073e9SAndroid Build Coastguard Worker                weight_t = self.weight.t()
1223*da0073e9SAndroid Build Coastguard Worker                res = torch.matmul(x, weight_t)
1224*da0073e9SAndroid Build Coastguard Worker                return res
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker        # Linear with bias pattern.
1227*da0073e9SAndroid Build Coastguard Worker        pattern_count_map = {
1228*da0073e9SAndroid Build Coastguard Worker            "Tensor = prim::CallFunction": -1,
1229*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_prepack": 1,
1230*da0073e9SAndroid Build Coastguard Worker            "prepacked::linear_clamp_run": 1,
1231*da0073e9SAndroid Build Coastguard Worker        }
1232*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1233*da0073e9SAndroid Build Coastguard Worker            DecomposedLinearAddmm(), pattern_count_map, data_shape
1234*da0073e9SAndroid Build Coastguard Worker        )
1235*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1236*da0073e9SAndroid Build Coastguard Worker            DecomposedLinearMatmulAdd(), pattern_count_map, data_shape
1237*da0073e9SAndroid Build Coastguard Worker        )
1238*da0073e9SAndroid Build Coastguard Worker        TestXNNPACKRewritePass.validate_transformed_module(
1239*da0073e9SAndroid Build Coastguard Worker            DecomposedLinearMatmul(), pattern_count_map, data_shape
1240*da0073e9SAndroid Build Coastguard Worker        )
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker
@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class TestXNNPACKConv1dTransformPass(TestCase):
    """Tests for the JIT pass that rewrites aten::conv1d into aten::conv2d
    (torch._C._jit_pass_transform_conv1d_to_conv2d), both applied directly
    and as part of optimize_for_mobile, validating graph patterns and
    numerical agreement with the untransformed module."""

    @staticmethod
    def validate_transform_conv1d_to_conv2d(
        self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape
    ):
        # NOTE(review): despite the name, `self` here is the eager module under
        # test -- this is a @staticmethod invoked as
        # TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
        #     module, ...), mirroring the style of the other validators in this
        # file. Renaming the parameter would clarify intent.
        #
        # For both torch.jit.script and torch.jit.trace this helper:
        #   1. records a reference result from the freshly scripted module,
        #   2. applies the conv1d->conv2d pass in place on the scripted module,
        #   3. separately builds an optimize_for_mobile variant,
        #   4. round-trips each through save/load, then checks FileCheck
        #      patterns and numerical closeness against the reference.
        #
        # Pattern-map value semantics (shared by both maps):
        #   v == 0  -> pattern must be present (FileCheck().check)
        #   v == -1 -> pattern must be absent  (FileCheck().check_not)
        #   v  > 0  -> pattern must occur exactly v times
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            # In-place graph transform: conv1d nodes become conv2d nodes.
            torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
            optimized_scripted_model = optimize_for_mobile(scripted_model)

            # Serialize/deserialize so the pattern checks and numeric checks
            # run against a graph that survived a save/load round trip.
            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)

            for pattern, v in pattern_count_transformed_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_scripted_model.graph
                    )
            transformed_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(
                ref_result, transformed_result, rtol=1e-2, atol=1e-3
            )

            # Same round trip and checks for the optimize_for_mobile variant,
            # which is expected to use the prepacked XNNPACK conv2d ops.
            optimized_buffer = io.BytesIO()
            torch.jit.save(optimized_scripted_model, optimized_buffer)
            optimized_buffer.seek(0)
            deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)

            for pattern, v in pattern_count_optimized_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(
                        deserialized_optimized_scripted_model.graph
                    )
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_optimized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_optimized_scripted_model.graph
                    )
            xnnpack_result = deserialized_optimized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skipIf(IS_FBCODE, "T137513244")
    def test_conv1d_basic(self):
        """Sweep conv1d hyperparameters (batch, channels, width, groups,
        kernel, stride, padding, dilation) and check each module is rewritten
        to conv2d and, after mobile optimization, to prepacked conv2d ops."""
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)

        for hparams in itertools.product(
            batch_size_list,
            input_channels_per_group_list,
            width_list,
            output_channels_per_group_list,
            groups_list,
            kernel_list,
            stride_list,
            padding_list,
            dilation_list,
        ):
            (
                batch_size,
                input_channels_per_group,
                width,
                output_channels_per_group,
                groups,
                kernel,
                stride,
                padding,
                dilation,
            ) = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            # conv1d weight layout: (out_channels, in_channels/groups, kernel)
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = output_channels

            class Conv1D(torch.nn.Module):
                # Functional conv1d wrapper closing over the current hparams.
                def __init__(self) -> None:
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                def forward(self, x):
                    return F.conv1d(
                        x,
                        self.weight,
                        self.bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )

            data_shape = (batch_size, input_channels, width)
            # After the transform pass: no conv1d, exactly one conv2d.
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": 1,
            }
            # After optimize_for_mobile: only the prepacked run op remains
            # (the prepack call is folded away by prepack removal).
            pattern_count_optimized_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": -1,
                "prepacked::conv2d_clamp_prepack": -1,
                "prepacked::conv2d_clamp_run": 1,
            }

            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
                Conv1D(),
                pattern_count_transformed_map,
                pattern_count_optimized_map,
                data_shape,
            )

    # See https://github.com/pytorch/pytorch/issues/46066
    @slowTest
    def test_conv1d_with_relu_fc(self):
        """Same sweep as test_conv1d_basic but through a conv1d -> relu ->
        flatten -> linear network, so the pass is exercised inside a larger
        graph rather than on a lone conv op."""
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)
        output_features_list = range(1, 3)

        for hparams in itertools.product(
            batch_size_list,
            input_channels_per_group_list,
            width_list,
            output_channels_per_group_list,
            groups_list,
            kernel_list,
            stride_list,
            padding_list,
            dilation_list,
            output_features_list,
        ):
            (
                batch_size,
                input_channels_per_group,
                width,
                output_channels_per_group,
                groups,
                kernel,
                stride,
                padding,
                dilation,
                output_features,
            ) = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = output_channels
            # Standard conv output-size formula; needed to size the FC layer
            # that consumes the flattened conv output.
            conv_output_width = (
                int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
            )
            fc_weight_shape = (output_features, output_channels * conv_output_width)
            fc_bias_shape = output_features

            class Net(torch.nn.Module):
                # conv1d -> relu -> flatten -> linear, closing over hparams.
                def __init__(self) -> None:
                    super().__init__()
                    self.conv_weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.conv_bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                    self.fc_weight = torch.nn.Parameter(
                        torch.rand(fc_weight_shape), requires_grad=False
                    )
                    self.fc_bias = torch.nn.Parameter(
                        torch.rand(fc_bias_shape), requires_grad=False
                    )

                def forward(self, x):
                    x = F.conv1d(
                        x,
                        self.conv_weight,
                        self.conv_bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )
                    x = F.relu(x)
                    x = x.view(x.size(0), -1)
                    x = F.linear(x, self.fc_weight, self.fc_bias)
                    return x

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": 1,
            }
            pattern_count_optimized_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": -1,
                "prepacked::conv2d_clamp_prepack": -1,
                "prepacked::conv2d_clamp_run": 1,
            }
            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
                Net(),
                pattern_count_transformed_map,
                pattern_count_optimized_map,
                data_shape,
            )
1497*da0073e9SAndroid Build Coastguard Worker
1498*da0073e9SAndroid Build Coastguard Worker
# Standard PyTorch test entry point: discover and run every test in this file.
if __name__ == "__main__":
    run_tests()
1501