xref: /aosp_15_r20/external/executorch/backends/cadence/aot/tests/test_simplify_ops_passes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3
4import unittest
5from typing import cast, Optional, Tuple
6
7import executorch.backends.cadence.aot.ops_registrations  # noqa
8import torch
9from executorch.backends.cadence.aot.compiler import export_to_edge
10from executorch.backends.cadence.aot.pass_utils import count_node
11from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
12from executorch.exir.dialects._ops import ops as exir_ops
13from parameterized.parameterized import parameterized
14from torch.fx.passes.infra.pass_base import PassResult
15
16
class TestSimplifyOpsPasses(unittest.TestCase):
    """Tests for ``SimplifySliceOpPass`` on exported edge-dialect graphs."""

    # Shared parameter rows for both tests below.
    # Each row is (in_shape, src_shape, dim, start, end, step).
    # With start=15 > end=3 and a positive step, the slice along dim 1 is empty.
    # That is why src_shape has size 0 in that dimension.
    _EMPTY_SLICE_CASES = [
        [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
    ]

    @parameterized.expand(_EMPTY_SLICE_CASES)
    @torch.no_grad()
    def test_simplify_slice_scatter_op(
        self,
        in_shape: Tuple[int, ...],
        src_shape: Tuple[int, ...],
        dim: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> None:
        """Check that SimplifySliceOpPass removes an empty-range slice_scatter.

        Args:
            in_shape: shape of the tensor being scattered into.
            src_shape: shape of the source tensor scattered into ``in_shape``.
            dim: dimension along which the slice is taken.
            start, end, step: slice bounds forwarded to ``torch.slice_scatter``.
        """

        class SliceScatter(torch.nn.Module):
            """Single-op module wrapping ``torch.slice_scatter``."""

            def __init__(
                self, dim: int, start: Optional[int], end: Optional[int], step: int
            ):
                super().__init__()
                self.dim = dim
                self.start = start
                self.end = end
                self.step = step

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return torch.slice_scatter(
                    x, y, self.dim, self.start, self.end, self.step
                )

        model = SliceScatter(dim, start, end, step)
        x = torch.randn(in_shape)
        y = torch.randn(src_shape)
        graph_module = export_to_edge(model, (x, y)).exported_program().graph_module

        # Precondition: the op must be present before the pass runs.
        # Otherwise the post-pass count of 0 below proves nothing.
        # This mirrors the pre-check in test_simplify_slice_op.
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.slice_scatter.default), 1
        )

        p = SimplifySliceOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
        )

    @parameterized.expand(_EMPTY_SLICE_CASES)
    @torch.no_grad()
    def test_simplify_slice_op(
        self,
        in_shape: Tuple[int, ...],
        src_shape: Tuple[int, ...],  # Unused; present because rows are shared.
        dim: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> None:
        """Check that an empty-range slice_copy is replaced by a single full op.

        Args:
            in_shape: shape of the tensor being sliced.
            src_shape: unused here; kept so both tests can share parameter rows.
            dim: dimension along which the slice is taken.
            start, end, step: slice bounds forwarded to ``torch.slice_copy``.
        """

        class SliceCopy(torch.nn.Module):
            """Single-op module wrapping ``torch.slice_copy``."""

            def __init__(
                self, dim: int, start: Optional[int], end: Optional[int], step: int
            ):
                super().__init__()
                self.dim = dim
                self.start = start
                self.end = end
                self.step = step

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.slice_copy(
                    x, dim=self.dim, start=self.start, end=self.end, step=self.step
                )

        # Create a model with single slice copy op.
        model = SliceCopy(dim, start, end, step)
        x = torch.randn(in_shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        # Precondition: exactly one slice_copy exists before the pass.
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )

        p = SimplifySliceOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # The slice is gone and a single full op appears in its place.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
        )
109