# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestSliceCopy(unittest.TestCase):
    """Tests for lowering tensor slicing (aten.slice / slice_copy) to XNNPACK."""

    def _test_slice_copy(self, module, inputs, copy_count=1, edge_copy_count=1):
        """Check that every slice in ``module`` is delegated to XNNPACK.

        The module is exported, converted to the edge dialect, partitioned,
        serialized and executed, and its outputs are compared against eager
        mode.

        Args:
            module: The ``torch.nn.Module`` under test.
            inputs: Tuple of example inputs used for export and for the final
                output comparison.
            copy_count: Expected number of ``aten.slice.Tensor`` nodes right
                after export.
            edge_copy_count: Expected number of edge-dialect ``slice_copy``
                nodes after ``to_edge``. This may be lower than
                ``copy_count`` when identity slices are optimized away.
        """
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.slice.Tensor": copy_count})
            .to_edge()
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": edge_copy_count
                }
            )
            .partition()
            # All slices are expected to land in a single delegate call ...
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # ... leaving no slice_copy behind in the top-level graph.
            .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _test_slice_not_delegated(
        self, module, inputs, slice_count=None, dynamic_shapes=None
    ):
        """Check that XNNPACK declines to lower the slices in ``module``.

        Args:
            module: The ``torch.nn.Module`` under test.
            inputs: Tuple of example inputs used for export.
            slice_count: If not ``None``, the expected number of
                ``aten.slice.Tensor`` nodes right after export. ``None``
                skips that check.
            dynamic_shapes: Optional dynamic-shape spec forwarded to
                ``Tester``.
        """
        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes).export()
        if slice_count is not None:
            tester.check_count({"torch.ops.aten.slice.Tensor": slice_count})
        # No delegate call should be produced: the slices stay un-lowered.
        tester.to_edge_transform_and_lower().check_not(
            ["torch.ops.higher_order.executorch_call_delegate"]
        )

    def test_fp16_slice_copy(self):
        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5).to(torch.float16),)
        self._test_slice_copy(SliceCopy(), inputs, copy_count=3, edge_copy_count=3)

    def test_fp32_slice_copy(self):
        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5),)
        self._test_slice_copy(SliceCopy(), inputs, copy_count=3, edge_copy_count=3)

    def test_fp32_slice_copy_memory_format(self):
        class ConvSlice(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=1,
                    out_channels=3,
                    kernel_size=(3, 3),
                    padding=1,
                    bias=False,
                )

            def forward(self, x):
                y = self.conv(x)
                return y[:, :, 2:3, -2:]

        inputs = (torch.randn(1, 1, 3, 3),)
        # Export emits four slices; the two identity slices (the ":" on the
        # first two dims) are optimized away, leaving two slice_copy nodes.
        self._test_slice_copy(ConvSlice(), inputs, copy_count=4, edge_copy_count=2)

    def test_fp32_slice_copy_stride_non_1(self):
        """
        XNNPACK does not support strided slicing.
        """

        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[:3:2, :, :]

        inputs = (torch.randn(5, 5, 5),)
        self._test_slice_not_delegated(Slice(), inputs, slice_count=3)

    def test_fp32_slice_copy_dim_0(self):
        """
        XNNPACK does not support zero-size dims.
        """

        class Slice(torch.nn.Module):
            def forward(self, x):
                return x[-1:3, 2:, 3:3]

        inputs = (torch.randn(5, 5, 5),)
        self._test_slice_not_delegated(Slice(), inputs, slice_count=3)

    def test_fp32_static_slice_with_dynamic_dim(self):
        """
        XNNPACK does not support dynamic dims with a static slice.
        """

        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                return x[1:3, -2:, :-1]

        inputs = (torch.randn(5, 5, 5),)
        self._test_slice_not_delegated(
            SliceCopy(),
            inputs,
            dynamic_shapes=({2: torch.export.Dim("dim_2", min=4, max=100)},),
        )

    # Note: Slice ends up as slice_copy later in the process, but during quantization,
    # it's still slice, which isn't supported by the XNNPACK quantizer.
    # NOTE(review): the leading underscore already keeps unittest from collecting
    # this method, so the skip reason below is never reported. Renaming it to
    # `test_qs8_slice_copy` would surface it as a tracked skipped test.
    @unittest.skip("T156004676 - slice isn't propagated")
    def _test_qs8_slice_copy(self):
        class SliceCopy(torch.nn.Module):
            def forward(self, x):
                y = x + x
                z = y[1:3, -2:, :-1]
                return z

        inputs = (torch.randn(5, 5, 5),)
        (
            Tester(SliceCopy(), inputs)
            .quantize()
            .export()
            .check_node_count(
                {
                    "aten::slice.Tensor": 3,
                    "quantized_decomposed::quantize_per_tensor": 3,
                }
            )
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )