# xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_add.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

8import unittest
9
10from typing import Tuple
11
12import torch
13from executorch.backends.arm.test import common
14from executorch.backends.arm.test.tester.arm_tester import ArmTester
15from executorch.exir import EdgeCompileConfig
16from executorch.exir.backend.compile_spec_schema import CompileSpec
17from parameterized import parameterized
18
19
class TestSimpleAdd(unittest.TestCase):
    """Tests a single add op, x+x and x+y.

    Every operand set is lowered through three flows:

    * TOSA MI: floating point, no quantization; outputs are compared.
    * TOSA BI: quantized; outputs are compared with ``qtol=1``.
    * Ethos-U55 / Ethos-U85 BI: quantized and serialized. The program is only
      executed (on a Corstone FVP) when the ``corstone300`` option is enabled.
    """

    class Add(torch.nn.Module):
        """Single-input module computing ``x + x``."""

        # Each entry is a 1-tuple holding the single input tensor.
        test_parameters = [
            (torch.FloatTensor([1, 2, 3, 5, 7]),),
            (3 * torch.ones(8),),
            # NOTE(review): drawn once at import time with no explicit seed,
            # so the values may differ between runs.
            (10 * torch.randn(8),),
            (torch.ones(1, 1, 4, 4),),
            (torch.ones(1, 3, 4, 2),),
        ]

        def forward(self, x):
            return x + x

    class Add2(torch.nn.Module):
        """Two-input module computing ``x + y``."""

        # Each entry is an (x, y) pair. Some pairs have different shapes
        # (e.g. (1, 1, 4, 4) + (1, 1, 4, 1)) and therefore rely on broadcasting.
        test_parameters = [
            (
                torch.FloatTensor([1, 2, 3, 5, 7]),
                torch.FloatTensor([2, 1, 2, 1, 10]),
            ),
            (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
            (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
            (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
            # x is scaled by 10000 while y is not: operands of very different
            # magnitude.
            (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
        ]

        def forward(self, x, y):
            return x + y

    # Edge-lowering config shared by the TOSA pipelines below.
    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )

    def _test_add_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Lower ``module`` through the TOSA MI (non-quantized) flow and run it.

        Args:
            module: The module under test (``Add`` or ``Add2``).
            test_data: Positional inputs for ``module.forward``; one tensor for
                ``Add``, two for ``Add2``.

        Checks performed:
            * export contains exactly one ``aten.add.Tensor``;
            * no ``quantized_decomposed`` ops are present;
            * partitioning yields exactly one delegate call.

        The lowered program is then run on ``test_data``, and its outputs are
        compared via ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_add_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Quantize ``module``, lower it through TOSA BI, and run it.

        Args:
            module: The module under test (``Add`` or ``Add2``).
            test_data: Positional inputs for ``module.forward``.

        Same checks as the MI pipeline, except that ``quantized_decomposed``
        ops are required (not forbidden) after export. Output comparison uses
        ``qtol=1``, presumably a tolerance of one quantization step (confirm
        against ``ArmTester.run_method_and_compare_outputs``).
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge(config=self._edge_compile_config)
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_add_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor, ...],
    ):
        """Quantize, lower, and serialize ``module`` for an Ethos-U target.

        Args:
            module: The module under test (``Add`` or ``Add2``).
            compile_spec: Target compile spec, from ``common.get_u55_compile_spec``
                or ``common.get_u85_compile_spec``.
                NOTE(review): those helpers may return a list of ``CompileSpec``;
                confirm the annotation.
            test_data: Positional inputs for ``module.forward``.

        Returns:
            The serialized ``ArmTester``. The program is not executed here;
            callers decide whether to run it on an FVP.
        """
        tester = (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.add.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            # NOTE(review): unlike the TOSA pipelines, this does not pass
            # ``config=self._edge_compile_config``; confirm this is intended.
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

        return tester

    def _run_on_fvp_if_enabled(
        self,
        tester: ArmTester,
        test_data: Tuple[torch.Tensor, ...],
        target_board: str,
    ):
        """Run a serialized Ethos-U program on ``target_board`` if enabled.

        Does nothing unless the ``corstone300`` test option is enabled. If it
        is enabled, outputs are compared with ``qtol=1``.
        """
        # NOTE(review): the U85 tests pass "corstone-320" but are still gated on
        # the "corstone300" option; confirm that is the intended switch.
        if common.is_option_enabled("corstone300"):
            tester.run_method_and_compare_outputs(
                qtol=1, inputs=test_data, target_board=target_board
            )

    @parameterized.expand(Add.test_parameters)
    def test_add_tosa_MI(self, test_data: torch.Tensor):
        self._test_add_tosa_MI_pipeline(self.Add(), (test_data,))

    @parameterized.expand(Add.test_parameters)
    def test_add_tosa_BI(self, test_data: torch.Tensor):
        self._test_add_tosa_BI_pipeline(self.Add(), (test_data,))

    @parameterized.expand(Add.test_parameters)
    def test_add_u55_BI(self, test_data: torch.Tensor):
        inputs = (test_data,)
        tester = self._test_add_ethos_BI_pipeline(
            self.Add(),
            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
            inputs,
        )
        self._run_on_fvp_if_enabled(tester, inputs, "corstone-300")

    @parameterized.expand(Add.test_parameters)
    def test_add_u85_BI(self, test_data: torch.Tensor):
        inputs = (test_data,)
        tester = self._test_add_ethos_BI_pipeline(
            self.Add(),
            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
            inputs,
        )
        self._run_on_fvp_if_enabled(tester, inputs, "corstone-320")

    @parameterized.expand(Add2.test_parameters)
    def test_add2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        self._test_add_tosa_MI_pipeline(self.Add2(), (operand1, operand2))

    @parameterized.expand(Add2.test_parameters)
    def test_add2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        self._test_add_tosa_BI_pipeline(self.Add2(), (operand1, operand2))

    @parameterized.expand(Add2.test_parameters)
    def test_add2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        inputs = (operand1, operand2)
        tester = self._test_add_ethos_BI_pipeline(
            self.Add2(), common.get_u55_compile_spec(), inputs
        )
        self._run_on_fvp_if_enabled(tester, inputs, "corstone-300")

    @parameterized.expand(Add2.test_parameters)
    def test_add2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        inputs = (operand1, operand2)
        tester = self._test_add_ethos_BI_pipeline(
            self.Add2(), common.get_u85_compile_spec(), inputs
        )
        self._run_on_fvp_if_enabled(tester, inputs, "corstone-320")