xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_softmax.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

8import unittest
9
10from typing import Tuple
11
12import torch
13from executorch.backends.arm.test import common
14from executorch.backends.arm.test.tester.arm_tester import ArmTester
15from executorch.exir.backend.compile_spec_schema import CompileSpec
16from parameterized import parameterized
17
18
# Parameterized softmax cases, expanded below as (test_name, test_data, dim).
# Each spec row names the tensor factory and shape instead of a pre-built
# tensor. Rows are materialized in order, so the random factories are called
# in the same sequence as a literal list would call them.
_SOFTMAX_CASE_SPECS = (
    # (test_name, factory, shape, dim)
    ("zeros", torch.zeros, (10, 8, 5, 2), 0),
    ("zeros_neg_dim", torch.zeros, (10, 7, 8, 9), -4),
    ("ones", torch.ones, (10, 10), 1),
    ("ones_neg_dim", torch.ones, (10, 3, 4), -1),
    ("rand", torch.rand, (1, 2, 5, 8), 2),
    ("rand_neg_dim", torch.rand, (2, 10, 8, 10), -2),
    ("randn", torch.randn, (10, 10, 10, 10), 3),
    ("randn_neg_dim", torch.randn, (10, 5, 8, 7), -3),
)

test_data_suite = [
    (case_name, factory(*shape), softmax_dim)
    for case_name, factory, shape, softmax_dim in _SOFTMAX_CASE_SPECS
]
30
31
class TestSoftmax(unittest.TestCase):
    """Tests softmax through the Arm backend.

    Covers the TOSA MI and TOSA BI compile specs, plus BI compilation for the
    Ethos-U55 and Ethos-U85 compile specs.
    """

    class Softmax(torch.nn.Module):
        """Minimal module applying ``torch.nn.Softmax`` over a fixed ``dim``."""

        def __init__(self, dim: int = -1):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=dim)

        def forward(self, x):
            return self.softmax(x)

    def _test_softmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and run it.

        After export, the graph must contain ``aten.softmax.int`` and no
        ``quantized_decomposed`` ops. After partitioning, no edge-dialect
        ``_softmax`` op may remain and exactly one delegate call must exist.
        The ExecuTorch program is then executed via
        ``run_method_and_compare_outputs`` on ``test_data``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _build_softmax_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` with ``compile_spec``.

        This shared BI chain is used by both the TOSA BI and the Ethos-U BI
        pipelines, so their graph checks cannot drift apart.

        After quantize + export, ``aten.softmax.int`` must be absent, and both
        ``quantized_decomposed`` and ``aten.mul.Tensor`` ops must be present.
        NOTE(review): this presumably reflects softmax being decomposed during
        quantization; confirm against the quantizer.

        After partitioning, no edge-dialect ``_softmax`` op may remain and
        exactly one delegate call must exist.

        Returns:
            The object produced by ``.to_executorch()``, so callers can
            continue the fluent chain (e.g. ``run_method_and_compare_outputs``).
        """
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_softmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with TOSA-0.80.0+BI, then run it."""
        self._build_softmax_BI_pipeline(
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"), module, test_data
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_softmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U ``compile_spec``.

        Unlike the TOSA pipelines, this stops after ``to_executorch()`` and
        does not run the program. NOTE(review): presumably because no Ethos-U
        runtime/simulator is assumed here; confirm.
        """
        self._build_softmax_BI_pipeline(compile_spec, module, test_data)

    def _test_softmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Compile-only BI pipeline using the Ethos-U55 compile spec."""
        self._test_softmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_softmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Compile-only BI pipeline using the Ethos-U85 compile spec."""
        self._test_softmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    # In the test methods below, ``test_name`` is not used in the body; it is
    # the first field of each ``test_data_suite`` entry expanded by
    # ``parameterized.expand``.

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_MI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_u85_BI_pipeline(self.Softmax(dim=dim), (test_data,))
156