# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestCat(unittest.TestCase):
    class Cat2(torch.nn.Module):
        def forward(self, arg1, arg2):
            xs = [arg1, arg2]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat3(torch.nn.Module):
        def forward(self, arg1, arg2, arg3):
            xs = [arg1, arg2, arg3]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat4(torch.nn.Module):
        def forward(self, arg1, arg2, arg3, arg4):
            xs = [arg1, arg2, arg3, arg4]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    class Cat5(torch.nn.Module):
        def forward(self, arg1, arg2, arg3, arg4, arg5):
            xs = [arg1, arg2, arg3, arg4, arg5]
            x = torch.cat(xs)
            return x + x  # Quantize by propagation.

    def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
        """
        Lower `module` through both the legacy flow (to_edge + partition) and
        to_edge_transform_and_lower, and verify that cat is delegated.

        cat_num: number of tensors being concatenated (one Q/DQ pair per input).
        quant_ops: number of quantized ops in the graph (cat and add for most tests).
        """
        for legacy_mode in (True, False):
            tester = Tester(module, inputs)

            if quant:
                tester.quantize()

            tester.export().check_count({"torch.ops.aten.cat": 1})
            tester.dump_artifact()

            if quant:
                # Expect multiple quantize ops - one per input, cat, and add.
                tester.check_node_count(
                    {
                        # Q/DQ pair for each input and quantized op. For most
                        # tests, there are two quantized ops - cat and add.
                        torch.ops.quantized_decomposed.quantize_per_tensor.default: (
                            cat_num + quant_ops
                        )
                    }
                )

            if legacy_mode:
                tester.to_edge()
                tester.partition()
            else:
                tester.to_edge_transform_and_lower()

            if quant:
                tester.check_not(["torch.ops.quantized_decomposed"])

            (
                tester.check_count(
                    {"torch.ops.higher_order.executorch_call_delegate": 1}
                )
                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
                .to_executorch()
                .serialize()
                .run_method_and_compare_outputs()
            )

    def test_fp16_cat2(self):
        """
        Note: fp16 add is currently done in fp32; that needs to be fixed first.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat2(), inputs)

    def test_fp16_cat3(self):
        """
        Note: fp16 add is currently done in fp32; that needs to be fixed first.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat3(), inputs)

    def test_fp16_cat4(self):
        """
        Note: fp16 add is currently done in fp32; that needs to be fixed first.
        """
        inputs = (
            torch.randn(1, 2, 3).to(torch.float16),
            torch.randn(3, 2, 3).to(torch.float16),
            torch.randn(2, 2, 3).to(torch.float16),
            torch.randn(5, 2, 3).to(torch.float16),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_fp32_cat2(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs)

    def test_fp32_cat3(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs)

    def test_fp32_cat4(self):
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs)

    def test_qs8_cat2(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3))
        self._test_cat(self.Cat2(), inputs, cat_num=2, quant=True)

    def test_qs8_cat3(self):
        inputs = (torch.randn(1, 2, 3), torch.randn(3, 2, 3), torch.randn(2, 2, 3))
        self._test_cat(self.Cat3(), inputs, cat_num=3, quant=True)

    def test_qs8_cat4(self):
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
        )
        self._test_cat(self.Cat4(), inputs, cat_num=4, quant=True)

    def test_fp32_cat_unsupported(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not
        delegate here.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge_transform_and_lower()
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    def test_fp32_cat_unsupported_legacy_mode(self):
        """
        XNNPACK only supports concatenating up to 4 values, so it should not
        delegate here.
        """
        inputs = (
            torch.randn(1, 2, 3),
            torch.randn(3, 2, 3),
            torch.randn(2, 2, 3),
            torch.randn(5, 2, 3),
            torch.randn(1, 2, 3),
        )
        (
            Tester(self.Cat5(), inputs)
            .export()
            .check_count({"torch.ops.aten.cat": 1})
            .to_edge()
            .partition()
            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
        )

    class CatNegativeDim(torch.nn.Module):
        def forward(self, x, y):
            # Concatenate along the last dimension.
            return torch.cat([x, y], -1)

    def test_fp32_cat_negative_dim(self):
        inputs = (torch.randn(3, 2, 3), torch.randn(3, 2, 1))
        self._test_cat(self.CatNegativeDim(), inputs)

    class CatNhwc(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            # One cat input comes from conv, which the XNNPACK delegate keeps
            # in NHWC; cat along the channel dimension.
            x = self.conv(x)
            z = torch.concatenate((y, x, y, x), 1)
            return z + z

    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        self._test_cat(self.CatNhwc(), inputs, quant=True, quant_ops=3)

    class CatNhwc2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=1,
                out_channels=3,
                kernel_size=(3, 3),
                padding=1,
                bias=False,
            )

        def forward(self, x, y):
            # Both cat inputs come from conv (NHWC in the delegate); cat along
            # the last dimension.
            x = self.conv(x)
            y = self.conv(y)
            z = torch.concatenate((y, x, y, x), 3)
            return z + z

    @unittest.skip("T172862540 - Runtime failure.")
    def _test_qs8_cat_nhwc2(self):
        inputs = (torch.randn(1, 1, 3, 3), torch.randn(1, 1, 3, 3))
        self._test_cat(self.CatNhwc2(), inputs, quant=True, quant_ops=4)