xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/cat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestCat(unittest.TestCase):
    """Tests for lowering ``torch.cat`` through the XNNPACK ``Tester`` pipeline.

    The nested ``torch.nn.Module`` classes are the models under test. Most of
    them add the concatenated tensor to itself so that, when quantizing, the
    cat is quantized by propagation from the add.
    """

    class Cat2(torch.nn.Module):
        """Concatenate two tensors along the default dim and double the result."""

        def forward(self, arg1, arg2):
            xs = [arg1, arg2]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat3(torch.nn.Module):
        """Concatenate three tensors along the default dim and double the result."""

        def forward(self, arg1, arg2, arg3):
            xs = [arg1, arg2, arg3]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat4(torch.nn.Module):
        """Concatenate four tensors along the default dim and double the result."""

        def forward(self, arg1, arg2, arg3, arg4):
            xs = [arg1, arg2, arg3, arg4]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat5(torch.nn.Module):
        """Concatenate five tensors; used by the "unsupported" tests below."""

        def forward(self, arg1, arg2, arg3, arg4, arg5):
            xs = [arg1, arg2, arg3, arg4, arg5]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
        """Run ``module`` through both lowering flows and check cat is delegated.

        The pipeline is executed twice: once with the legacy
        ``to_edge()`` + ``partition()`` flow and once with
        ``to_edge_transform_and_lower()``. Each flow runs in its own
        ``subTest`` so a failure is attributed to the flow that caused it.

        Args:
            module: The ``torch.nn.Module`` to lower.
            inputs: Tuple of example inputs passed to ``Tester``.
            cat_num: Number of graph inputs that are expected to be quantized.
                Only used when ``quant`` is true.
            quant: When true, quantize before export and verify the Q/DQ nodes.
            quant_ops: Number of quantized ops expected in the graph (cat and
                add for most tests). Only used when ``quant`` is true.
        """
        for legacy_mode in (True, False):
            with self.subTest(legacy_mode=legacy_mode):
                tester = Tester(module, inputs)

                if quant:
                    tester.quantize()

                tester.export().check_count({"torch.ops.aten.cat": 1})
                tester.dump_artifact()

                if quant:
                    # Expect multiple quantize ops - one per input, cat, and add.
                    tester.check_node_count(
                        {
                            # Q/DQ pair for each input and quantized op. For
                            # most tests, there are two quantized ops - cat
                            # and add.
                            torch.ops.quantized_decomposed.quantize_per_tensor.default: (
                                cat_num + quant_ops
                            )
                        }
                    )

                if legacy_mode:
                    tester.to_edge()
                    tester.partition()
                else:
                    tester.to_edge_transform_and_lower()

                if quant:
                    # No quantized_decomposed ops may remain after lowering.
                    tester.check_not(["torch.ops.quantized_decomposed"])

                (
                    tester.check_count(
                        {"torch.ops.higher_order.executorch_call_delegate": 1}
                    )
                    # The edge-dialect cat must be gone from the top-level graph.
                    .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )

    def test_fp16_cat2(self):
        """Concatenating two fp16 tensors should be delegated."""
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat2(), inputs)

    def test_fp16_cat3(self):
        """Concatenating three fp16 tensors should be delegated."""
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat3(), inputs)

    def test_fp16_cat4(self):
        """Concatenating four fp16 tensors should be delegated."""
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
            torch.randn(5, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_fp32_cat2(self):
        """Concatenating two fp32 tensors should be delegated."""
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs)

    def test_fp32_cat3(self):
        """Concatenating three fp32 tensors should be delegated."""
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs)

    def test_fp32_cat4(self):
        """Concatenating four fp32 tensors should be delegated."""
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_qs8_cat2(self):
        """Quantized concatenation of two tensors should be delegated."""
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)

    def test_qs8_cat3(self):
        """Quantized concatenation of three tensors should be delegated."""
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)

    def test_qs8_cat4(self):
        """Quantized concatenation of four tensors should be delegated."""
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)

    def test_fp32_cat_unsupported(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge_transform_and_lower()
            # The cat must still be present in the edge graph (not delegated).
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    def test_fp32_cat_unsupported_legacy_mode(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.

        Same check as ``test_fp32_cat_unsupported`` but through the legacy
        ``to_edge()`` + ``partition()`` flow.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge()
            .partition()
            # The cat must still be present in the edge graph (not delegated).
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    class CatNegativeDim(torch.nn.Module):
        """Concatenate two tensors along the last dimension via ``dim=-1``."""

        def forward(self, x, y):
            return torch.cat([x, y], -1)

    def test_fp32_cat_negative_dim(self):
        """Concatenation along a negative dim should be delegated."""
        inputs = (torch.randn(3, 2, 3), torch.randn(3, 2, 1))
        self._test_cat(self.CatNegativeDim(), inputs)

    class CatNhwc(torch.nn.Module):
        """Convolve ``x`` only, then concatenate ``(y, x, y, x)`` along dim 1."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            x = self.conv(x)
            z = torch.concatenate((y, x, y, x), 1)
            return z + z

    # NOTE: the leading underscore keeps this method from being collected as a
    # test; it is additionally marked as skipped until the runtime failure is
    # resolved.
    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        # quant_ops=3: conv, cat and add.
        self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=3)

    class CatNhwc2(torch.nn.Module):
        """Convolve both inputs, then concatenate ``(y, x, y, x)`` along dim 3."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            x = self.conv(x)
            y = self.conv(y)
            z = torch.concatenate((y, x, y, x), 3)
            return z + z

    # NOTE: the leading underscore keeps this method from being collected as a
    # test; it is additionally marked as skipped until the runtime failure is
    # resolved.
    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc2(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        # quant_ops=4: two convs, cat and add. This exercises CatNhwc2 (the
        # original mistakenly instantiated CatNhwc here).
        # NOTE(review): expected quantize counts are unverified while the test
        # is skipped - confirm when re-enabling.
        self._test_cat(self.CatNhwc2(), inputs, quant=True, quant_ops=4)
238