# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data, dim)
    ("zeros", torch.zeros(10, 10, 10, 10), 0),
    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
    ("ones", torch.ones(10, 10), 1),
    ("rand_neg_dim", torch.rand(10, 10, 10), -1),
    ("rand", torch.rand(10, 10, 10, 10), 2),
25    ("rand_neg_dim", torch.rand(10, 10, 2, 3), -2),
26    ("randn", torch.randn(10, 10, 5, 10), 3),
27    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
28]
29
30
31class TestLogSoftmax(unittest.TestCase):
32    """Tests logsoftmax."""
33
34    class LogSoftmax(torch.nn.Module):
35        def __init__(self, dim: int = -1):
36            super().__init__()
37            self.logsoftmax = torch.nn.LogSoftmax(dim=dim)
38
39        def forward(self, x):
40            return self.logsoftmax(x)
41
42    def _test_logsoftmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
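        """Export and lower the float module with the TOSA-0.80.0+MI compile spec,
        check that aten.log_softmax ends up in the delegate rather than in the edge
        graph, and compare the lowered output against the reference output.
        """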
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.log_softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_logsoftmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
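        """Quantize, export and lower the module with the TOSA-0.80.0+BI compile
        spec, check that the quantized log_softmax is fully delegated, and compare
        outputs against the reference with a quantization tolerance of 1 (qtol=1).
        """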
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_logsoftmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
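        """Quantize, export and lower the module with the given Ethos-U compile
        spec and check that the quantized log_softmax is fully delegated. The
        lowered program is not executed in this pipeline.
        """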
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__log_softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_logsoftmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
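        """Run the Ethos-U BI pipeline with the Ethos-U55 compile spec."""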
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_logsoftmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
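        """Run the Ethos-U BI pipeline with the Ethos-U85 compile spec."""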
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u55_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u85_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )