xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_sum.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9from typing import Tuple
10
11import torch
12from executorch.backends.arm.test import common
13from executorch.backends.arm.test.tester.arm_tester import ArmTester
14from executorch.exir import EdgeCompileConfig
15from executorch.exir.backend.compile_spec_schema import CompileSpec
16from parameterized import parameterized
17
# One test case for the sum operator. The elements are, in order: the input
# tensor, the dimension (or list of dimensions) to reduce, and the keepdim flag.
exampledata_t = Tuple[torch.Tensor, int | list[int], bool]
"""(data, dim(s), keepdim)"""
20
21
class TestSum(unittest.TestCase):
    """Tests sum which sums all elements along some specified dimensions.
    keepdim specifies whether the dimension that is summed should
    be squeezed or not.

    Every entry of ``Sum.test_parameters`` is run through four pipelines:
    TOSA MI, TOSA BI (quantized), and the quantized Ethos-U pipeline with
    the U55 and the U85 compile specs.
    """

    class Sum(torch.nn.Module):
        """Module under test; forwards its arguments to ``torch.Tensor.sum``."""

        # Each entry is a 1-tuple so that ``parameterized.expand`` hands the
        # inner (data, dim(s), keepdim) tuple to the test as ONE argument.
        test_parameters: list[Tuple[exampledata_t]] = [
            ((torch.rand(10), 0, True),),
            ((torch.rand(10, 10), 1, False),),
            ((torch.rand(10, 10, 10), [-3, 1], True),),
            ((torch.rand(2, 1, 5, 8), 1, False),),
            ((torch.rand(1, 2, 3, 4), 3, True),),
            ((torch.rand(1, 2, 8, 8), [2, 3, 0], True),),
        ]

        def forward(self, x: torch.Tensor, dim: int | list[int], keepdim: bool):
            """Sum ``x`` over ``dim``.

            Args:
                x: Tensor to reduce.
                dim: A single dimension or a list of dimensions to sum over
                    (the test parameters use both forms).
                keepdim: Forwarded unchanged to ``x.sum``.

            Returns:
                The result of ``x.sum(dim=dim, keepdim=keepdim)``.
            """
            return x.sum(dim=dim, keepdim=keepdim)

    # Shared by all pipelines below so that they lower to edge with the same
    # settings.
    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

    def _test_sum_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: exampledata_t
    ) -> None:
        """Run ``module`` through the TOSA-0.80.0+MI pipeline without quantizing.

        Checks that the exported graph holds exactly one
        ``aten.sum.dim_IntList`` and no ``quantized_decomposed`` ops, that
        partitioning leaves exactly one delegate call, and finally runs the
        method and compares its outputs.

        Args:
            module: The module to lower.
            test_data: (data, dim(s), keepdim), used both as example inputs
                and as the inputs for the final comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_sum_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: exampledata_t
    ) -> None:
        """Run ``module`` through the quantized TOSA-0.80.0+BI pipeline.

        Same stages as the MI pipeline, except that the module is quantized
        first, ``quantized_decomposed`` ops are required to be present after
        export, and the output comparison is called with ``qtol=1``.

        Args:
            module: The module to lower.
            test_data: (data, dim(s), keepdim), used both as example inputs
                and as the inputs for the final comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_sum_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: exampledata_t,
        compile_spec: CompileSpec,
    ) -> None:
        """Lower the quantized ``module`` with an Ethos-U compile spec.

        Runs the same graph checks as the TOSA BI pipeline, but ends with
        ``serialize()`` instead of running the method and comparing outputs.

        Args:
            module: The module to lower.
            test_data: (data, dim(s), keepdim), used as example inputs.
            compile_spec: Compile spec selecting the Ethos-U target.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.sum.dim_IntList": 1})
            .check(["torch.ops.quantized_decomposed"])
            # Use the shared config, matching the TOSA pipelines above.
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

    @parameterized.expand(Sum.test_parameters)
    def test_sum_tosa_MI(self, test_data: exampledata_t) -> None:
        """Check sum on the TOSA MI pipeline for one parameter set."""
        self._test_sum_tosa_MI_pipeline(self.Sum(), test_data)

    @parameterized.expand(Sum.test_parameters)
    def test_sum_tosa_BI(self, test_data: exampledata_t) -> None:
        """Check sum on the quantized TOSA BI pipeline for one parameter set."""
        self._test_sum_tosa_BI_pipeline(self.Sum(), test_data)

    @parameterized.expand(Sum.test_parameters)
    def test_sum_u55_BI(self, test_data: exampledata_t) -> None:
        """Check that sum lowers and serializes with the U55 compile spec."""
        # NOTE(review): U55 passes permute_memory_to_nhwc=False while U85
        # passes True - confirm the difference is intended.
        self._test_sum_ethosu_BI_pipeline(
            self.Sum(),
            test_data,
            common.get_u55_compile_spec(permute_memory_to_nhwc=False),
        )

    @parameterized.expand(Sum.test_parameters)
    def test_sum_u85_BI(self, test_data: exampledata_t) -> None:
        """Check that sum lowers and serializes with the U85 compile spec."""
        self._test_sum_ethosu_BI_pipeline(
            self.Sum(),
            test_data,
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
        )
130