# xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_div.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
7
8import logging
9import unittest
10
11from typing import Optional, Tuple, Union
12
13import torch
14from executorch.backends.arm.test import common
15from executorch.backends.arm.test.tester.arm_tester import ArmTester
16from parameterized import parameterized
17
# Module-level logger for this test file, with its threshold set to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
20
# Parameter sets consumed by ``parameterized.expand`` in the tests below.
# Each entry is a 4-tuple; the first element becomes part of the generated
# test name and the remaining three are passed to the test method.
#
# NOTE(review): the random tensors are created at import time without a fixed
# seed, and ``torch.rand`` samples from [0, 1), so the divisors in the "rand"
# cases can be arbitrarily close to zero. Confirm this is intended before
# tightening any tolerances.
test_data_suite = [
    # (test_name, input, other, rounding_mode) See torch.div() for info
    (
        "op_div_rank1_ones",
        torch.ones(5),
        torch.ones(5),
        None,
    ),
    (
        "op_div_rank1_rand",
        torch.rand(5) * 5,
        torch.rand(5) * 5,
        None,
    ),
    (
        "op_div_rank1_negative_ones",
        torch.ones(5) * (-1),
        torch.ones(5) * (-1),
        None,
    ),
    (
        "op_div_rank4_ones",
        torch.ones(5, 10, 25, 20),
        torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_negative_ones",
        (-1) * torch.ones(5, 10, 25, 20),
        torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_ones_div_negative",
        torch.ones(5, 10, 25, 20),
        (-1) * torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_large_rand",
        200 * torch.rand(5, 10, 25, 20),
        torch.rand(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20),
        torch.rand(5, 10, 25, 20),
        None,
    ),
    (
        # The divisor is shifted by +1, so it stays within [1, 2).
        "op_div_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20) + 1,
        torch.rand(5, 10, 25, 20) + 1,
        None,
    ),
]
78
79
class TestDiv(unittest.TestCase):
    """Tests division.

    Each public test is expanded over ``test_data_suite`` and hands a
    ``(input_, other_)`` pair to one of three ArmTester pipelines: TOSA MI,
    TOSA BI (quantized) and U55 BI (quantized).
    """

    class Div(torch.nn.Module):
        """Minimal module wrapping ``torch.div``."""

        def forward(
            self,
            input_: Union[torch.Tensor, torch.types.Number],
            other_: Union[torch.Tensor, torch.types.Number],
            rounding_mode: Optional[str] = None,
        ):
            """Return ``input_ / other_``.

            ``rounding_mode`` is only forwarded to ``torch.div`` when it is
            not ``None``; otherwise ``torch.div`` is called without that
            keyword at all.
            """
            if rounding_mode is None:
                return torch.div(input=input_, other=other_)
            return torch.div(
                input=input_, other=other_, rounding_mode=rounding_mode
            )

    def _test_div_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Run ``module`` through the "TOSA-0.80.0+MI" pipeline (no quantize step).

        The chain expects one ``aten.div.Tensor`` and no
        ``quantized_decomposed`` ops after export, one delegate call after
        partitioning, and finally compares outputs for ``test_data`` using
        the tester's default tolerances.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.div.Tensor": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_div_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Run ``module`` through the quantized "TOSA-0.80.0+BI" pipeline.

        After quantize + export the graph is expected to contain one
        ``aten.reciprocal.default`` and one ``aten.mul.Tensor`` (rather than
        a div op) together with ``quantized_decomposed`` ops. Outputs are
        compared with ``atol=1, rtol=0.1``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, atol=1, rtol=0.1)
        )

    def _test_div_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, torch.Tensor]
    ):
        """Run ``module`` through the quantized pipeline with the U55 compile spec.

        Performs the same graph checks as the TOSA BI pipeline but stops
        after ``to_executorch()``: this pipeline does not execute the
        program or compare outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # NOTE: in all three tests below ``rounding_mode`` is received from the
    # suite but is not part of ``test_data``; the module therefore always
    # runs with its default (``None``).

    @parameterized.expand(test_data_suite)
    def test_div_tosa_MI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Division through the TOSA MI pipeline."""
        test_data = (input_, other_)
        self._test_div_tosa_MI_pipeline(self.Div(), test_data)

    @parameterized.expand(test_data_suite)
    def test_div_tosa_BI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Division through the quantized TOSA BI pipeline."""
        test_data = (input_, other_)
        self._test_div_tosa_BI_pipeline(self.Div(), test_data)

    @parameterized.expand(test_data_suite)
    def test_div_u55_BI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Division through the quantized U55 pipeline (lowering only)."""
        test_data = (input_, other_)
        self._test_div_u55_BI_pipeline(self.Div(), test_data)
193