# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

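# Tests for lowering torch.nn.Conv1d to the XNNPACK delegate: standalone and
# fused with BatchNorm1d, in fp16, fp32, and quantized (qs8) configurations.
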
import unittest

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn

from executorch.backends.xnnpack.test.tester import RunPasses, Tester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class TestConv1d(unittest.TestCase):
    class Conv1d(torch.nn.Module):
        def __init__(self, dtype: torch.dtype = torch.float):
            groups = 1
            stride = (2,)
            padding = (1,)
            dilation = (1,)
            in_channels = 2
            out_channels = 1
            kernel_size = (3,)

            super().__init__()

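            # With kernel 3, stride 2, padding 1, and dilation 1, the output
            # length is floor((L_in + 2*1 - 1*(3 - 1) - 1) / 2) + 1, so an
            # input of shape (N, 2, 4) produces an output of shape (N, 1, 2).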
            self.conv1d = torch.nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation,
                bias=True,
            ).to(dtype)

        def forward(self, x):
            return self.conv1d(x)

    class Conv1dBatchNormSequential(torch.nn.Module):
        def __init__(self):
            groups = 1
            stride = [1]
            padding = [1]
            dilation = [1]
            in_channels = 2
            out_channels = 2
            kernel_size = (3,)

            super().__init__()
            self.conv1 = torch.nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation,
                bias=True,
            )
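            # BatchNorm's feature count must match the preceding conv's output
            # channels (equal to in_channels here, since in == out == 2).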
            self.bn1 = randomize_bn(num_features=out_channels, dimensionality=1)
            self.conv2 = torch.nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation,
                bias=True,
            )
            self.bn2 = randomize_bn(num_features=out_channels, dimensionality=1)

        def forward(self, x):
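            # conv -> bn, twice, then a residual-style add of the two branch
            # outputs (shapes match because stride is 1 and padding preserves
            # the sequence length).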
            y = self.conv1(x)
            y = self.bn1(y)
            z = self.conv2(y)
            z = self.bn2(z)
            z = torch.add(y, z)
            return z

    def _test_conv1d(
        self,
        module,
        inputs,
        conv_count,
        quantized=False,
        dynamic_shape=None,
        passes=None,
        stage=None,
        skip_to_executorch=False,
    ):
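        # Pipeline: (optionally quantize), export, run passes, confirm the
        # expected number of aten.conv1d calls, lower to XNNPACK, then check
        # that no aten convolution remains outside the single delegate call.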
        tester = (
            (
                Tester(module, inputs, dynamic_shape).quantize()
                if quantized
                else Tester(module, inputs, dynamic_shape)
            )
            .export()
            .run_passes(passes)
            .check_count({"torch.ops.aten.conv1d.default": conv_count})
            .to_edge_transform_and_lower(stage)
            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        )
        # Some tests skip to_executorch because it would require loading the
        # quantized operator library, which those tests deliberately avoid.
        if not skip_to_executorch:
            tester.to_executorch().serialize().run_method_and_compare_outputs()

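    # Each test declares a dynamic batch dimension (min=2, max=10) so the
    # delegate is exercised with dynamic shapes, not just the sample input.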
    def test_fp16_conv1d(self):
        inputs = (torch.randn(2, 2, 4).to(torch.float16),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(
            self.Conv1d(dtype=torch.float16),
            inputs,
            conv_count=1,
            dynamic_shape=dynamic_shapes,
        )

    def test_fp32_conv1d(self):
        inputs = (torch.randn(2, 2, 4),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(self.Conv1d(), inputs, 1, dynamic_shape=dynamic_shapes)

    def test_fp32_conv1d_batchnorm_seq(self):
        inputs = (torch.randn(2, 2, 4),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(
            self.Conv1dBatchNormSequential(), inputs, 2, dynamic_shape=dynamic_shapes
        )

    def test_qs8_conv1d(self):
        inputs = (torch.randn(2, 2, 4),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(
            self.Conv1d(), inputs, 1, quantized=True, dynamic_shape=dynamic_shapes
        )

    def test_qs8_conv1d_batchnorm_seq(self):
        inputs = (torch.randn(2, 2, 4),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(
            self.Conv1dBatchNormSequential(),
            inputs,
            2,
            quantized=True,
            dynamic_shape=dynamic_shapes,
        )

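    # Exercises partitioning a quantized graph with an FP32-only
    # XnnpackPartitioner config; constant_prop_pass runs first so constant
    # subgraphs are folded before partitioning. to_executorch is skipped to
    # avoid requiring the quantized operator library.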
    def test_qs8_conv1d_with_floating_point_partitioner(self):
        inputs = (torch.randn(2, 2, 4),)
        dynamic_shapes = ({0: torch.export.Dim("batch", min=2, max=10)},)
        self._test_conv1d(
            self.Conv1d(),
            inputs,
            1,
            quantized=True,
            dynamic_shape=dynamic_shapes,
            stage=ToEdgeTransformAndLower(
                partitioners=[
                    XnnpackPartitioner(config_precisions=ConfigPrecisionType.FP32)
                ]
            ),
            passes=RunPasses(pass_functions=[constant_prop_pass]),
            skip_to_executorch=True,
        )

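
if __name__ == "__main__":
    # Allow running this file directly; the suite is normally run via pytest.
    unittest.main()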