xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_mul.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10import torch
11from executorch.backends.arm.test import common
12from executorch.backends.arm.test.tester.arm_tester import ArmTester
13from executorch.exir.backend.backend_details import CompileSpec
14from parameterized import parameterized
15
# Parameterized cases shared by every test in this file. Each entry is
# (test_name, input, other); see torch.mul() for the meaning of input/other.
# Several cases pair operands of different shapes (a size-1 dimension on one
# side), so the multiplication relies on broadcasting.
test_data_sute = [
    # Same-shape operands.
    ("op_mul_rank1_ones", torch.ones(5), torch.ones(5)),
    # `other` has a size-1 leading dimension.
    ("op_mul_rank2_rand", torch.rand(4, 5), torch.rand(1, 5)),
    ("op_mul_rank3_randn", torch.randn(10, 5, 2), torch.randn(10, 5, 2)),
    ("op_mul_rank4_randn", torch.randn(5, 10, 25, 20), torch.randn(5, 10, 25, 20)),
    # `input` has a size-1 leading dimension; `other` is all -1.
    (
        "op_mul_rank4_ones_mul_negative",
        torch.ones(1, 10, 25, 20),
        -torch.ones(5, 10, 25, 20),
    ),
    # Non-positive input scaled by 200, multiplied by a (5, 1, 1, 20) operand.
    (
        "op_mul_rank4_negative_large_rand",
        -200 * torch.rand(5, 10, 25, 20),
        torch.rand(5, 1, 1, 20),
    ),
    # Normally distributed input scaled by 200; `other` has a size-1 last dim.
    (
        "op_mul_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20),
        torch.rand(5, 10, 25, 1),
    ),
]
54
55
class TestMul(unittest.TestCase):
    """Multiplication tests for the Arm backend.

    Each case in ``test_data_sute`` is run through four pipelines: a
    "TOSA-0.80.0+MI" pipeline, a quantized "TOSA-0.80.0+BI" pipeline, and a
    quantized pipeline using the U55 and the U85 compile spec respectively.
    """

    class Mul(torch.nn.Module):
        """Module under test: multiplies its two inputs with ``*``."""

        def forward(self, input_: torch.Tensor, other_: torch.Tensor):
            product = input_ * other_
            return product

    def _test_mul_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Lower ``module`` with the "TOSA-0.80.0+MI" spec and compare outputs.

        No quantize stage is run; the exported graph must contain exactly one
        ``aten.mul.Tensor`` and no ``quantized_decomposed`` ops, and after
        partitioning exactly one ``executorch_call_delegate`` call.
        """
        tosa_spec = common.get_tosa_compile_spec(
            "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
        )
        tester = ArmTester(module, example_inputs=test_data, compile_spec=tosa_spec)

        # Every stage is invoked on the value returned by the previous stage.
        exported = (
            tester.export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check_not(["torch.ops.quantized_decomposed"])
        )
        partitioned = (
            exported.to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        )
        partitioned.to_executorch().run_method_and_compare_outputs(inputs=test_data)

    def _test_mul_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Quantize and lower ``module`` with "TOSA-0.80.0+BI", then compare.

        The exported graph must contain exactly one ``aten.mul.Tensor`` plus
        ``quantized_decomposed`` ops, and after partitioning exactly one
        ``executorch_call_delegate`` call. Outputs are compared with
        ``qtol=1.0``.
        """
        tosa_spec = common.get_tosa_compile_spec(
            "TOSA-0.80.0+BI", permute_memory_to_nhwc=True
        )
        tester = ArmTester(module, example_inputs=test_data, compile_spec=tosa_spec)

        # Every stage is invoked on the value returned by the previous stage.
        exported = (
            tester.quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
        )
        partitioned = (
            exported.to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        )
        partitioned.to_executorch().run_method_and_compare_outputs(
            inputs=test_data, qtol=1.0
        )

    def _test_mul_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: tuple[torch.Tensor, torch.Tensor],
    ):
        """Quantize and lower ``module`` with the caller-supplied compile spec.

        Runs the same graph checks as the TOSA BI pipeline but stops after
        ``to_executorch()``: the outputs are not executed or compared here.
        """
        tester = ArmTester(
            module, example_inputs=test_data, compile_spec=compile_spec
        )

        # Every stage is invoked on the value returned by the previous stage.
        exported = (
            tester.quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
        )
        partitioned = (
            exported.to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
        )
        partitioned.to_executorch()

    # In the generated tests below, test_name is not read by the method body;
    # it is the first element of each case handed to parameterized.expand.

    @parameterized.expand(test_data_sute)
    def test_mul_tosa_MI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        # Non-quantized TOSA MI pipeline with output comparison.
        self._test_mul_tosa_MI_pipeline(self.Mul(), (input_, other_))

    @parameterized.expand(test_data_sute)
    def test_mul_tosa_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        # Quantized TOSA BI pipeline with output comparison.
        self._test_mul_tosa_BI_pipeline(self.Mul(), (input_, other_))

    @parameterized.expand(test_data_sute)
    def test_mul_u55_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        # Quantized lowering with the U55 compile spec; no output comparison.
        u55_spec = common.get_u55_compile_spec()
        self._test_mul_ethosu_BI_pipeline(u55_spec, self.Mul(), (input_, other_))

    @parameterized.expand(test_data_sute)
    def test_mul_u85_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        # Quantized lowering with the U85 compile spec; no output comparison.
        u85_spec = common.get_u85_compile_spec()
        self._test_mul_ethosu_BI_pipeline(u85_spec, self.Mul(), (input_, other_))
175