# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestCat(unittest.TestCase):
    """Tests lowering of aten.cat to the XNNPACK delegate (fp16, fp32, qs8)."""

    class Cat2(torch.nn.Module):
        def forward(self, arg1, arg2):
            xs = [arg1, arg2]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat3(torch.nn.Module):
        def forward(self, arg1, arg2, arg3):
            xs = [arg1, arg2, arg3]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat4(torch.nn.Module):
        def forward(self, arg1, arg2, arg3, arg4):
            xs = [arg1, arg2, arg3, arg4]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat5(torch.nn.Module):
        def forward(self, arg1, arg2, arg3, arg4, arg5):
            xs = [arg1, arg2, arg3, arg4, arg5]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
        """Export, (optionally) quantize, lower, and run `module`, checking that
        the cat op is fully delegated to XNNPACK in both legacy and
        to_edge_transform_and_lower flows.

        Args:
            module: the eager module under test.
            inputs: tuple of example inputs for export/run.
            cat_num: number of cat inputs; used to count expected Q/DQ pairs.
            quant: when True, run the PT2E quantization flow first.
            quant_ops: number of quantized ops besides the inputs (for most
                tests: cat and add, i.e. 2).
        """
        for legacy_mode in (True, False):
            tester = Tester(module, inputs)

            if quant:
                tester.quantize()

            tester.export().check_count({"torch.ops.aten.cat": 1})

            if quant:
                # Expect multiple quantize ops - one per input, cat, and add.
                tester.check_node_count(
                    {
                        # Q/DQ pair for each input and quantized op. For most tests, there are
                        # two quantized ops - cat and add.
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: (
                            cat_num + quant_ops
                        )
                    }
                )

            if legacy_mode:
                tester.to_edge()
                tester.partition()
            else:
                tester.to_edge_transform_and_lower()

            if quant:
                tester.check_not(["torch.ops.quantized_decomposed"])

            (
                tester.check_count(
                    {"torch.ops.higher_order.executorch_call_delegate": 1}
                )
                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
                .to_executorch()
                .serialize()
                .run_method_and_compare_outputs()
            )

    def test_fp16_cat2(self):
        """
        fp16 add is currently done in fp32, so this mainly exercises fp16 cat.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat2(), inputs)

    def test_fp16_cat3(self):
        """
        fp16 add is currently done in fp32, so this mainly exercises fp16 cat.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat3(), inputs)

    def test_fp16_cat4(self):
        """
        fp16 add is currently done in fp32, so this mainly exercises fp16 cat.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
            torch.randn(5, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_fp32_cat2(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs)

    def test_fp32_cat3(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs)

    def test_fp32_cat4(self):
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_qs8_cat2(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)

    def test_qs8_cat3(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)

    def test_qs8_cat4(self):
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)

    def test_fp32_cat_unsupported(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge_transform_and_lower()
            # cat must remain in the edge graph, i.e. not delegated.
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    def test_fp32_cat_unsupported_legacy_mode(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge()
            .partition()
            # cat must remain in the edge graph, i.e. not delegated.
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    class CatNegativeDim(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.cat([x, y], -1)

    def test_fp32_cat_negative_dim(self):
        inputs = (torch.randn(3, 2, 3), torch.randn(3, 2, 1))
        self._test_cat(self.CatNegativeDim(), inputs)

    class CatNhwc(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            x = self.conv(x)
            # Cat along the channel dim with one NHWC (conv output) and one
            # contiguous input.
            z = torch.concatenate((y, x, y, x), 1)
            return z + z

    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=3)

    class CatNhwc2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            x = self.conv(x)
            y = self.conv(y)
            # Cat along the innermost dim with both inputs in NHWC (conv
            # outputs).
            z = torch.concatenate((y, x, y, x), 3)
            return z + z

    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc2(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        # Bug fix: previously instantiated CatNhwc, leaving CatNhwc2 untested.
        self._test_cat(self.CatNhwc2(), inputs, quant=True, quant_ops=4)