# Owner(s): ["oncall: mobile"]

import io
import itertools
import unittest

from hypothesis import assume, given, strategies as st

import torch
import torch.backends.xnnpack
import torch.testing._internal.hypothesis_utils as hu
from torch.nn import functional as F
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    run_tests,
    slowTest,
    TEST_WITH_TSAN,
    TestCase,
)
from torch.utils.mobile_optimizer import optimize_for_mobile


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK; the failures have no clear root cause.",
)
class TestXNNPACKOps(TestCase):
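    """Run the prepacked XNNPACK linear / conv2d / conv_transpose2d ops directly
    and compare their results against the torch.nn.functional references."""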
    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        data_shape = [batch_size] + list(data_shape)
        input_data = torch.rand(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        input_size=st.integers(2, 32),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear_1d_input(self, input_size, weight_output_dim, use_bias):
        input_data = torch.rand(input_size)
        weight = torch.rand((weight_output_dim, input_data.shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        ref_result = F.linear(input_data, weight, bias)
        packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
        output_linearprepacked = torch.ops.prepacked.linear_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
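        # Skip parameter combinations where the dilated kernel does not fit
        # inside the padded input (the output would otherwise be empty).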
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        ref_result = F.conv2d(
            input_data, weight, bias, strides, paddings, dilations, groups
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
            weight, bias, strides, paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(1, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
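        # Skip shapes with an empty output, and keep output_padding below both
        # stride and dilation (conv_transpose2d rejects larger values).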
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        # Note that groups and dilation are in reverse order relative to conv2d.
        ref_result = F.conv_transpose2d(
            input_data,
            weight,
            bias,
            strides,
            paddings,
            output_paddings,
            groups,
            dilation,
        )
        packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(
            weight, bias, strides, paddings, output_paddings, dilations, groups
        )
        xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(
            input_data, packed_weight_bias
        )
        torch.testing.assert_close(
            ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3
        )


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK; the failures have no clear root cause.",
)
class TestXNNPACKSerDes(TestCase):
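    """Check that modules built on the prepacked XNNPACK ops produce the same
    results before and after a torch.jit.save / torch.jit.load round trip."""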
    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        data_shape=hu.array_shapes(1, 3, 2, 64),
        weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
    )
    def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias):
        class Linear(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.weight = weight
                self.bias = bias

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearPrePacked(torch.nn.Module):
            def __init__(self, weight, bias=None):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(
                    weight, bias
                )

            def forward(self, x):
                return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias)

        data_shape = [batch_size] + list(data_shape)
        weight = torch.rand((weight_output_dim, data_shape[-1]))
        if use_bias:
            bias = torch.rand(weight_output_dim)
        else:
            bias = None
        scripted_linear = torch.jit.script(Linear(weight, bias))
        scripted_linear_clamp_prepacked = torch.jit.script(
            LinearPrePacked(weight, bias)
        )
        input_data = torch.rand(data_shape)
        ref_result = scripted_linear(input_data)
        output_linearprepacked = scripted_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

        # Serialize the modules and then deserialize
        input_data = torch.rand(data_shape)
        buffer = io.BytesIO()
        torch.jit.save(scripted_linear, buffer)
        buffer.seek(0)
        deserialized_linear = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_linear_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_linear_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_linear(input_data)
        output_linearprepacked = deserialized_linear_clamp_prepacked(input_data)
        torch.testing.assert_close(
            ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3
        )

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        use_bias,
        format,
    ):
        class Conv2D(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DPrePacked(torch.nn.Module):
            def __init__(self, weight, bias, strides, paddings, dilations, groups):
                super().__init__()
                self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack(
                    weight, bias, strides, paddings, dilations, groups
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2D(weight, bias, strides, paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups)
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        output_pad_h=st.integers(0, 2),
        output_pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_conv2d_transpose(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        output_pad_h,
        output_pad_w,
        dilation,
        use_bias,
        format,
    ):
        class Conv2DT(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.weight = weight
                self.bias = bias
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        class Conv2DTPrePacked(torch.nn.Module):
            def __init__(
                self,
                weight,
                bias,
                strides,
                paddings,
                output_paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.packed_weight_bias = (
                    torch.ops.prepacked.conv2d_transpose_clamp_prepack(
                        weight,
                        bias,
                        strides,
                        paddings,
                        output_paddings,
                        dilations,
                        groups,
                    )
                )

            def forward(self, x):
                return torch.ops.prepacked.conv2d_transpose_clamp_run(
                    x, self.packed_weight_bias
                )

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)
        assume((output_pad_h < stride_h) and (output_pad_h < dilation))
        assume((output_pad_w < stride_w) and (output_pad_w < dilation))

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        weight = torch.rand(
            (input_channels, output_channels_per_group, kernel_h, kernel_w)
        )
        bias = None
        if use_bias:
            bias = torch.rand(output_channels)

        scripted_conv2d = torch.jit.script(
            Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups)
        )
        scripted_conv2d_clamp_prepacked = torch.jit.script(
            Conv2DTPrePacked(
                weight, bias, strides, paddings, output_paddings, dilations, groups
            )
        )
        ref_result = scripted_conv2d(input_data)
        xnnpack_result = scripted_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d, buffer)
        buffer.seek(0)
        deserialized_conv2d = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_conv2d_clamp_prepacked, buffer)
        buffer.seek(0)
        deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_conv2d(input_data)
        xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skip(
        "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488"
    )
    @given(
        batch_size=st.integers(0, 3),
        input_channels_per_group=st.integers(1, 32),
        height=st.integers(5, 64),
        width=st.integers(5, 64),
        output_channels_per_group=st.integers(1, 32),
        groups=st.integers(1, 16),
        kernel_h=st.integers(1, 7),
        kernel_w=st.integers(1, 7),
        stride_h=st.integers(1, 2),
        stride_w=st.integers(1, 2),
        pad_h=st.integers(0, 2),
        pad_w=st.integers(0, 2),
        dilation=st.integers(1, 2),
        linear_weight_output_dim=st.integers(2, 64),
        use_bias=st.booleans(),
        format=st.sampled_from(
            [None, torch.preserve_format, torch.contiguous_format, torch.channels_last]
        ),
    )
    def test_combined_model(
        self,
        batch_size,
        input_channels_per_group,
        height,
        width,
        output_channels_per_group,
        groups,
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation,
        linear_weight_output_dim,
        use_bias,
        format,
    ):
        class M(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv_weight = conv_weight
                self.conv_bias = conv_bias
                self.linear_weight = linear_weight
                self.linear_bias = linear_bias
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return F.relu(o)

        class MPrePacked(torch.nn.Module):
            def __init__(
                self,
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            ):
                super().__init__()
                self.conv2d_clamp_run_weight_bias = (
                    torch.ops.prepacked.conv2d_clamp_prepack(
                        conv_weight, conv_bias, strides, paddings, dilations, groups
                    )
                )
                self.linear_clamp_run_weight_bias = (
                    torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias)
                )

            def forward(self, x):
                o = torch.ops.prepacked.conv2d_clamp_run(
                    x, self.conv2d_clamp_run_weight_bias
                )
                o = o.permute([0, 2, 3, 1])
                o = torch.ops.prepacked.linear_clamp_run(
                    o, self.linear_clamp_run_weight_bias
                )
                return F.relu(o)

        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1)
        assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1)

        input_data = torch.rand((batch_size, input_channels, height, width))
        if format is not None:
            input_data = input_data.contiguous(memory_format=format)
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = None
        if use_bias:
            conv_bias = torch.rand(output_channels)

        # Run the conv once just to find the shape of its result, so that the
        # weight shape for the following linear layer can be determined.
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]

        linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape))
        linear_bias = None
        if use_bias:
            linear_bias = torch.rand(linear_weight_output_dim)

        scripted_m = torch.jit.script(
            M(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        scripted_m_prepacked = torch.jit.script(
            MPrePacked(
                conv_weight,
                conv_bias,
                linear_weight,
                linear_bias,
                strides,
                paddings,
                dilations,
                groups,
            )
        )
        ref_result = scripted_m(input_data)
        xnnpack_result = scripted_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

        # Serialize the modules and then deserialize
        input_data = torch.rand((batch_size, input_channels, height, width))
        input_data = input_data.contiguous(memory_format=torch.channels_last)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m, buffer)
        buffer.seek(0)
        deserialized_m = torch.jit.load(buffer)
        buffer = io.BytesIO()
        torch.jit.save(scripted_m_prepacked, buffer)
        buffer.seek(0)
        deserialized_m_prepacked = torch.jit.load(buffer)
        ref_result = deserialized_m(input_data)
        xnnpack_result = deserialized_m_prepacked(input_data)
        torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN fails with XNNPACK; the failures have no clear root cause.",
)
class TestXNNPACKRewritePass(TestCase):
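    """Validate the JIT rewrite passes that replace linear/conv patterns with
    prepacked XNNPACK ops, fold prepacking at freeze time, and fuse clamps."""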
    @staticmethod
    def validate_transformed_module(
        # The first argument is the module under test; it is named "self"
        # to keep flake8 happy.
        self,
        pattern_count_map,
        data_shape,
        prepack_removal=False,
        fuse_clamping_ops=False,
    ):
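        # pattern_count_map maps a FileCheck pattern to an expected count:
        #   0  -> the pattern must appear at least once,
        #   -1 -> the pattern must not appear,
        #   n  -> the pattern must appear exactly n times.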
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
            if fuse_clamping_ops or prepack_removal:
                scripted_model._c = torch._C._freeze_module(scripted_model._c)
            if fuse_clamping_ops:
                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
            if prepack_removal:
                torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)
            for pattern, v in pattern_count_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_scripted_model.graph
                    )
            xnnpack_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    def test_linear(self):
        data_shape = [2, 3, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class Linear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, self.bias)

        class LinearNoBias(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )

            def forward(self, x):
                return F.linear(x, self.weight, None)

        # Linear with bias pattern.
        pattern_count_map = {
            "Tensor = prim::CallFunction": -1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Linear(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            LinearNoBias(), pattern_count_map, data_shape
        )

        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        output_pad_h = output_pad_w = 0
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        output_paddings = (output_pad_h, output_pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (
            output_channels,
            input_channels_per_group,
            kernel_h,
            kernel_w,
        )
        conv_transpose_weight_shape = (
            input_channels,
            output_channels_per_group,
            kernel_h,
            kernel_w,
        )
        conv_bias_shape = output_channels

        class Conv2D(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )

        class Conv2DT(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(conv_transpose_weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.output_paddings = output_paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv_transpose2d(
                    x,
                    self.weight,
                    self.bias,
                    self.strides,
                    self.paddings,
                    self.output_paddings,
                    self.groups,
                    self.dilations,
                )

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2D(), pattern_count_map, data_shape
        )

        transpose_data_shape = (batch_size, input_channels, height, width)
        transpose_pattern_count_map = {
            "Tensor = aten::conv_transpose2d": -1,
            "prepacked::conv2d_transpose_clamp_prepack": 1,
            "prepacked::conv2d_transpose_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            Conv2DT(), transpose_pattern_count_map, transpose_data_shape
        )

        input_data = torch.rand((batch_size, input_channels, height, width))
        conv_weight = torch.rand(
            (output_channels, input_channels_per_group, kernel_h, kernel_w)
        )
        conv_bias = torch.rand(output_channels)
        result = F.conv2d(
            input_data, conv_weight, conv_bias, strides, paddings, dilations, groups
        )
        linear_input_shape = result.shape[1]
        linear_weight_shape = (weight_output_dim, linear_input_shape)

        class M(torch.nn.Module):
            def __init__(self, activation_fn=F.relu):
                super().__init__()
                self.conv_weight = torch.nn.Parameter(
                    torch.rand(conv_weight_shape), requires_grad=False
                )
                self.conv_bias = torch.nn.Parameter(
                    torch.rand(conv_bias_shape), requires_grad=False
                )
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups
                self.activation_fn = activation_fn

            def forward(self, x):
                o = F.conv2d(
                    x,
                    self.conv_weight,
                    self.conv_bias,
                    self.strides,
                    self.paddings,
                    self.dilations,
                    self.groups,
                )
                o = self.activation_fn(o)
                o = o.permute([0, 2, 3, 1])
                o = F.linear(o, self.linear_weight, self.linear_bias)
                return self.activation_fn(o)

        pattern_count_map = {
            "Tensor = aten::conv2d": -1,
            "prepacked::conv2d_clamp_prepack": 1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["Tensor = prim::CallFunction"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )

        # Out-of-place relu fusion test.
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace relu fusion test.
        pattern_count_map = {
            "aten::relu": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::relu"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.relu_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Out-of-place hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        # Inplace hardtanh fusion test.
        pattern_count_map = {
            "aten::hardtanh_": 2,
            "prepacked::conv2d_clamp_prepack": -1,
            "prepacked::conv2d_clamp_run": 1,
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True
        )
        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
        pattern_count_map["aten::hardtanh_"] = -1
        TestXNNPACKRewritePass.validate_transformed_module(
            M(F.hardtanh_),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

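        # Two clamping ops in a row: only the one immediately following the
        # prepacked op is expected to be fused; the second stays in the graph.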
        class MFusionAntiPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.relu(o)
                o = F.hardtanh(o)
                return o

        # Unfusable hardtanh.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh cannot be fused.
            "aten::relu": -1,  # relu is fused.
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPattern(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

        class MFusionAntiPatternParamMinMax(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(
                    torch.rand(linear_weight_shape), requires_grad=False
                )
                self.linear_bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                min = x[0, 0]
                max = min + 10
                o = F.linear(x, self.linear_weight, self.linear_bias)
                o = F.hardtanh(o, min, max)
                return o

        # Unfusable hardtanh: min/max are tensors computed at runtime, so they
        # cannot be folded into the prepacked op's clamp parameters.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh cannot be fused.
            "prepacked::linear_clamp_prepack": -1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            MFusionAntiPatternParamMinMax(),
            pattern_count_map,
            (16, linear_weight_shape[1]),
            prepack_removal=True,
            fuse_clamping_ops=True,
        )

    def test_decomposed_linear(self):
        data_shape = [2, 32]
        weight_output_dim = 24
        weight_shape = (weight_output_dim, data_shape[-1])

        class DecomposedLinearAddmm(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                return torch.addmm(self.bias, x, weight_t)

        class DecomposedLinearMatmulAdd(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                y = torch.matmul(x, weight_t)
                res = y.add_(self.bias)
                return res

        class DecomposedLinearMatmul(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(
                    torch.rand(weight_shape), requires_grad=False
                )
                self.bias = torch.nn.Parameter(
                    torch.rand(weight_output_dim), requires_grad=False
                )

            def forward(self, x):
                weight_t = self.weight.t()
                res = torch.matmul(x, weight_t)
                return res

        # Linear with bias pattern.
        pattern_count_map = {
            "Tensor = prim::CallFunction": -1,
            "prepacked::linear_clamp_prepack": 1,
            "prepacked::linear_clamp_run": 1,
        }
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearAddmm(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearMatmulAdd(), pattern_count_map, data_shape
        )
        TestXNNPACKRewritePass.validate_transformed_module(
            DecomposedLinearMatmul(), pattern_count_map, data_shape
        )


@unittest.skipUnless(
    torch.backends.xnnpack.enabled,
    "XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.",
)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class TestXNNPACKConv1dTransformPass(TestCase):
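    """Validate the conv1d -> conv2d transform pass, and that the transformed
    graph is subsequently rewritten to prepacked ops by optimize_for_mobile."""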
    @staticmethod
    def validate_transform_conv1d_to_conv2d(
        self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape
    ):
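        # Applies the conv1d -> conv2d transform to the scripted module, then
        # builds an optimized module via optimize_for_mobile; both deserialized
        # graphs are FileChecked with the same 0 / -1 / n count convention used
        # by TestXNNPACKRewritePass.validate_transformed_module above.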
        input_data = torch.normal(1, 20, size=data_shape)

        for jit_method in ["script", "trace"]:
            module_instance = self
            if jit_method == "script":
                scripted_model = torch.jit.script(module_instance)
            else:
                scripted_model = torch.jit.trace(module_instance, input_data)
            scripted_model.eval()
            ref_result = scripted_model(input_data)
            torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
            optimized_scripted_model = optimize_for_mobile(scripted_model)

            buffer = io.BytesIO()
            torch.jit.save(scripted_model, buffer)
            buffer.seek(0)
            deserialized_scripted_model = torch.jit.load(buffer)

            for pattern, v in pattern_count_transformed_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_scripted_model.graph
                    )
            transformed_result = deserialized_scripted_model(input_data)
            torch.testing.assert_close(
                ref_result, transformed_result, rtol=1e-2, atol=1e-3
            )

            optimized_buffer = io.BytesIO()
            torch.jit.save(optimized_scripted_model, optimized_buffer)
            optimized_buffer.seek(0)
            deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer)

            for pattern, v in pattern_count_optimized_map.items():
                if v == 0:
                    FileCheck().check(pattern).run(
                        deserialized_optimized_scripted_model.graph
                    )
                elif v == -1:
                    FileCheck().check_not(pattern).run(
                        deserialized_optimized_scripted_model.graph
                    )
                else:
                    FileCheck().check_count(pattern, v, exactly=True).run(
                        deserialized_optimized_scripted_model.graph
                    )
            xnnpack_result = deserialized_optimized_scripted_model(input_data)
            torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)

    @unittest.skipIf(IS_FBCODE, "T137513244")
    def test_conv1d_basic(self):
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)

        for hparams in itertools.product(
            batch_size_list,
            input_channels_per_group_list,
            width_list,
            output_channels_per_group_list,
            groups_list,
            kernel_list,
            stride_list,
            padding_list,
            dilation_list,
        ):
            (
                batch_size,
                input_channels_per_group,
                width,
                output_channels_per_group,
                groups,
                kernel,
                stride,
                padding,
                dilation,
            ) = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = output_channels

            class Conv1D(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                def forward(self, x):
                    return F.conv1d(
                        x,
                        self.weight,
                        self.bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": 1,
            }
            pattern_count_optimized_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": -1,
                "prepacked::conv2d_clamp_prepack": -1,
                "prepacked::conv2d_clamp_run": 1,
            }

            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
                Conv1D(),
                pattern_count_transformed_map,
                pattern_count_optimized_map,
                data_shape,
            )

    # See https://github.com/pytorch/pytorch/issues/46066
    @slowTest
    def test_conv1d_with_relu_fc(self):
        batch_size_list = range(1, 3)
        input_channels_per_group_list = range(10, 12)
        width_list = range(10, 12)
        output_channels_per_group_list = range(10, 12)
        groups_list = range(1, 3)
        kernel_list = range(1, 4)
        stride_list = range(1, 3)
        padding_list = range(0, 3)
        dilation_list = range(1, 3)
        output_features_list = range(1, 3)

        for hparams in itertools.product(
            batch_size_list,
            input_channels_per_group_list,
            width_list,
            output_channels_per_group_list,
            groups_list,
            kernel_list,
            stride_list,
            padding_list,
            dilation_list,
            output_features_list,
        ):
            (
                batch_size,
                input_channels_per_group,
                width,
                output_channels_per_group,
                groups,
                kernel,
                stride,
                padding,
                dilation,
                output_features,
            ) = hparams

            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            conv_weight_shape = (output_channels, input_channels_per_group, kernel)
            conv_bias_shape = output_channels
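            # Standard conv output-size formula:
            # floor((W + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1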
            conv_output_width = (
                int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
            )
            fc_weight_shape = (output_features, output_channels * conv_output_width)
            fc_bias_shape = output_features

            class Net(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv_weight = torch.nn.Parameter(
                        torch.rand(conv_weight_shape), requires_grad=False
                    )
                    self.conv_bias = torch.nn.Parameter(
                        torch.rand(conv_bias_shape), requires_grad=False
                    )
                    self.stride = stride
                    self.padding = padding
                    self.dilation = dilation
                    self.groups = groups

                    self.fc_weight = torch.nn.Parameter(
                        torch.rand(fc_weight_shape), requires_grad=False
                    )
                    self.fc_bias = torch.nn.Parameter(
                        torch.rand(fc_bias_shape), requires_grad=False
                    )

                def forward(self, x):
                    x = F.conv1d(
                        x,
                        self.conv_weight,
                        self.conv_bias,
                        self.stride,
                        self.padding,
                        self.dilation,
                        self.groups,
                    )
                    x = F.relu(x)
                    x = x.view(x.size(0), -1)
                    x = F.linear(x, self.fc_weight, self.fc_bias)
                    return x

            data_shape = (batch_size, input_channels, width)
            pattern_count_transformed_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": 1,
            }
            pattern_count_optimized_map = {
                "Tensor = aten::conv1d": -1,
                "Tensor = aten::conv2d": -1,
                "prepacked::conv2d_clamp_prepack": -1,
                "prepacked::conv2d_clamp_run": 1,
            }
            TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d(
                Net(),
                pattern_count_transformed_map,
                pattern_count_optimized_map,
                data_shape,
            )


if __name__ == "__main__":
    run_tests()