xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/slice_copy.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestSliceCopy(unittest.TestCase):
    """Tests lowering of aten slice ops to the XNNPACK delegate.

    Each positive test exports a small module, checks the expected op
    counts before and after to_edge, partitions to XNNPACK, verifies the
    slice ops were fully consumed by the delegate, and compares runtime
    outputs. Negative tests verify that unsupported slice configurations
    are NOT delegated.
    """

    def _test_slice_copy(self, module, inputs, copy_count=1, edge_copy_count=1):
        """Run the full lower-and-compare pipeline for a slicing module.

        Args:
            module: Eager module whose forward performs tensor slicing.
            inputs: Tuple of example inputs used for export and runtime
                comparison.
            copy_count: Expected number of ``aten.slice.Tensor`` nodes in
                the exported graph (one per sliced dimension).
            edge_copy_count: Expected number of edge-dialect
                ``slice_copy`` nodes after to_edge; may be lower than
                ``copy_count`` because identity slices are optimized away.
        """
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.slice.Tensor": copy_count})
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": edge_copy_count
                }
            )
            .partition()
            # After partitioning, all slices must live inside one delegate
            # and none may remain in the edge graph.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_slice_copy(self):
        """Slicing a 3-D fp16 tensor on all three dims is delegated."""

        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5).to(torch.float16),)
        # Three sliced dims -> three slice nodes, none optimized away.
        self._test_slice_copy(SliceCopy(), inputs, 3, 3)

    def test_fp32_slice_copy(self):
        """Slicing a 3-D fp32 tensor on all three dims is delegated."""

        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5),)
        self._test_slice_copy(SliceCopy(), inputs, 3, 3)

    def test_fp32_slice_copy_memory_format(self):
        """Slicing a conv output (channels-last candidate) is delegated."""

        class ConvSlice(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x):
                y = self.conv(x)
                return y[:, :, 2:3, -2:]

        inputs = (torch.randn(1, 1, 3, 3),)
        # Note that two of the slices are optimized away as they are identity.
        self._test_slice_copy(ConvSlice(), inputs, 4, 2)

    def test_fp32_slice_copy_stride_non_1(self):
        """
        XNNPACK does not support strided slicing.
        """

        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[:3:2, :, :]

        module = Slice()
        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.slice.Tensor": 3})
            .to_edge_transform_and_lower()
            # Step-2 slice must not be claimed by the partitioner.
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )

    def test_fp32_slice_copy_dim_0(self):
        """
        XNNPACK does not support 0-size dims.
        """

        class Slice(torch.nn.Module):
            def forward(self, x):
                # x[..., 3:3] produces a zero-size dimension.
                return x[-1:3, 2:, 3:3]

        module = Slice()
        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.slice.Tensor": 3})
            .to_edge_transform_and_lower()
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )

    def test_fp32_static_slice_with_dynamic_dim(self):
        """
        XNNPACK does not support dynamic dims with static slice
        """

        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(
                SliceCopy(),
                inputs,
                dynamic_shapes=({2: torch.export.Dim("dim_2", min=4, max=100)},),
            )
            .export()
            .to_edge_transform_and_lower()
            # Dynamic dim 2 blocks delegation of the static slice.
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
        )

    # Note: Slice ends up as slice_copy later in the process, but during quantization,
    # it's still slice, which isn't supported by the XNNPACK quantizer.
    # NOTE(review): the leading underscore also hides this from unittest
    # discovery, making the skip decorator redundant — presumably intentional
    # double-disabling; confirm before re-enabling.
    @unittest.skip("T156004676 - slice isn't propagated")
    def _test_qs8_slice_copy(self):
        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                # Add feeds the slice so quantization has a producer op.
                y = x + x
                z = y[1:3, -2:, :-1]
                return z

        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(SliceCopy(), inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    "aten::slice.Tensor": 3,
                    "quantized_decomposed::quantize_per_tensor": 3,
                }
            )
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
156