# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import unittest
from typing import cast, Optional, Tuple

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from torch.fx.passes.infra.pass_base import PassResult


class TestSimplifyOpsPasses(unittest.TestCase):
    """Checks the node counts left in an edge graph by SimplifySliceOpPass.

    Both parameterized cases use start=15 and end=3 along a dimension of
    size 16, i.e. a slice range that selects no elements.
    """

    @parameterized.expand(
        [
            # in_shape, src_shape, dim, start, end, step
            [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
        ]
    )
    @torch.no_grad()
    def test_simplify_slice_scatter_op(
        self,
        in_shape: Tuple[int, ...],
        src_shape: Tuple[int, ...],
        dim: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> None:
        """Assert that no slice_scatter node remains after the pass.

        Args:
            in_shape: Shape of the tensor being scattered into.
            src_shape: Shape of the tensor scattered into the slice.
            dim: Dimension along which the slice is taken.
            start: Start index of the slice.
            end: End index of the slice.
            step: Stride of the slice.
        """

        class SliceScatter(torch.nn.Module):
            """Module whose forward is a single torch.slice_scatter call."""

            def __init__(
                self, dim: int, start: Optional[int], end: Optional[int], step: int
            ) -> None:
                super().__init__()
                self.dim = dim
                self.start = start
                self.end = end
                self.step = step

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return torch.slice_scatter(
                    x, y, self.dim, self.start, self.end, self.step
                )

        model = SliceScatter(dim, start, end, step)
        x = torch.randn(in_shape)
        y = torch.randn(src_shape)
        graph_module = export_to_edge(model, (x, y)).exported_program().graph_module

        p = SimplifySliceOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # The pass must leave no slice_scatter node behind.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
        )

    @parameterized.expand(
        [
            # in_shape, src_shape, dim, start, end, step
            [(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
        ]
    )
    @torch.no_grad()
    def test_simplify_slice_op(
        self,
        in_shape: Tuple[int, ...],
        src_shape: Tuple[int, ...],
        dim: int,
        start: Optional[int] = None,
        end: Optional[int] = None,
        step: int = 1,
    ) -> None:
        """Assert that the pass swaps the slice_copy node for a single full node.

        Args:
            in_shape: Shape of the tensor being sliced.
            src_shape: Not used by this test; it is accepted only because
                the parameterized case supplies it positionally.
            dim: Dimension along which the slice is taken.
            start: Start index of the slice.
            end: End index of the slice.
            step: Stride of the slice.
        """

        class SliceCopy(torch.nn.Module):
            """Module whose forward is a single torch.slice_copy call."""

            def __init__(
                self, dim: int, start: Optional[int], end: Optional[int], step: int
            ) -> None:
                super().__init__()
                self.dim = dim
                self.start = start
                self.end = end
                self.step = step

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.slice_copy(
                    x, dim=self.dim, start=self.start, end=self.end, step=self.step
                )

        # Create a model with single slice copy op.
        model = SliceCopy(dim, start, end, step)
        x = torch.randn(in_shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        # Precondition: the exported graph contains exactly one slice_copy node.
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
        )

        p = SimplifySliceOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # After the pass the slice_copy node is gone and one full node exists.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
        )