# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir import EdgeCompileConfig
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

exampledata_t = Tuple[torch.Tensor, int | list[int], bool]
"""(data, dim(s), keepdim)"""


class TestSum(unittest.TestCase):
    """Tests sum which sums all elements along some specified dimensions.
    keepdim specifies whether the dimension that is summed should
    be squeezed or not.
    """

    class Sum(torch.nn.Module):
        """Minimal module wrapping ``Tensor.sum(dim=..., keepdim=...)``."""

        # Each entry is a 1-tuple wrapping one (data, dim(s), keepdim) triple,
        # so ``parameterized.expand`` hands the whole triple to the test as a
        # single ``test_data`` argument.
        test_parameters: list[Tuple[exampledata_t]] = [
            ((torch.rand(10), 0, True),),
            ((torch.rand(10, 10), 1, False),),
            ((torch.rand(10, 10, 10), [-3, 1], True),),
            ((torch.rand(2, 1, 5, 8), 1, False),),
            ((torch.rand(1, 2, 3, 4), 3, True),),
            ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
        ]

        def forward(self, x: torch.Tensor, dim: int | list[int], keepdim: bool):
            """Return ``x`` summed over ``dim``.

            Args:
                x: Input tensor.
                dim: A single dimension or a list of dimensions to reduce;
                    ``test_parameters`` exercises both forms.
                keepdim: Forwarded unchanged to ``Tensor.sum``.
            """
            return x.sum(dim=dim, keepdim=keepdim)

    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

    def _test_sum_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[exampledata_t]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI compile spec and compare outputs.

        Checks that the exported graph contains exactly one
        ``aten.sum.dim_IntList`` and no ``quantized_decomposed`` ops, and that
        partitioning yields exactly one delegate call.

        Args:
            module: Module under test.
            test_data: Inputs used both as example inputs and as run inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_sum_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[exampledata_t]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI compile spec.

        Checks that the exported graph contains exactly one
        ``aten.sum.dim_IntList`` together with ``quantized_decomposed`` ops,
        and that partitioning yields exactly one delegate call. Outputs are
        then compared with ``qtol=1``.

        Args:
            module: Module under test.
            test_data: Inputs used both as example inputs and as run inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_sum_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: tuple[exampledata_t],
        compile_spec: CompileSpec,
    ):
        """Quantize, lower and serialize ``module`` with the given compile spec.

        Performs the same graph checks as the TOSA BI pipeline, but finishes
        with ``serialize()`` and does not run the program or compare outputs.

        Args:
            module: Module under test.
            test_data: Inputs used as example inputs for export.
            compile_spec: Compile spec forwarded to ``ArmTester``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check(["torch.ops.quantized_decomposed"])
            # NOTE(review): unlike the TOSA pipelines, no
            # ``config=self._edge_compile_config`` is passed here -- confirm
            # whether that is intentional.
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

    @parameterized.expand(Sum.test_parameters)
    def test_sum_tosa_MI(self, test_data: tuple[exampledata_t]):
        """Run every ``Sum.test_parameters`` case through the TOSA MI pipeline."""
        self._test_sum_tosa_MI_pipeline(self.Sum(), test_data)

    @parameterized.expand(Sum.test_parameters)
    def test_sum_tosa_BI(self, test_data: tuple[exampledata_t]):
        """Run every ``Sum.test_parameters`` case through the TOSA BI pipeline."""
        self._test_sum_tosa_BI_pipeline(self.Sum(), test_data)

    @parameterized.expand(Sum.test_parameters)
    def test_sum_u55_BI(self, test_data: tuple[exampledata_t]):
        """Run every case through the Ethos-U pipeline with the U55 compile spec."""
        self._test_sum_ethosu_BI_pipeline(
            self.Sum(),
            test_data,
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
        )

    @parameterized.expand(Sum.test_parameters)
    def test_sum_u85_BI(self, test_data: tuple[exampledata_t]):
        """Run every case through the Ethos-U pipeline with the U85 compile spec."""
        self._test_sum_ethosu_BI_pipeline(
            self.Sum(),
            test_data,
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
        )