# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data, dim)
    ("zeros", torch.zeros(10, 10, 10, 10), 0),
    ("zeros_neg_dim", torch.zeros(10, 10, 10, 10), -4),
    ("ones", torch.ones(10, 10), 1),
    ("rand_neg_dim", torch.rand(10, 10, 10), -1),
    ("rand", torch.rand(10, 10, 10, 10), 2),
    # Distinct label so each case gets a unique, unambiguous test id.
    ("rand_4d_neg_dim", torch.rand(10, 10, 2, 3), -2),
    ("randn", torch.randn(10, 10, 5, 10), 3),
    ("randn_neg_dim", torch.randn(1, 10, 10, 10), -3),
]

# Node name of the edge-dialect log_softmax op. After partitioning it must not
# remain in the graph (it is expected to be absorbed into the Arm delegate).
# Shared by all pipelines so the check_not assertions cannot drift apart.
_EDGE_LOG_SOFTMAX_OP = "executorch_exir_dialects_edge__ops_aten__log_softmax_default"


class TestLogSoftmax(unittest.TestCase):
    """Tests logsoftmax lowering on the Arm backend (TOSA MI/BI, Ethos-U55/U85)."""

    class LogSoftmax(torch.nn.Module):
        """Minimal module applying ``torch.nn.LogSoftmax`` along ``dim``."""

        def __init__(self, dim: int = -1):
            super().__init__()
            self.logsoftmax = torch.nn.LogSoftmax(dim=dim)

        def forward(self, x):
            return self.logsoftmax(x)

    def _test_logsoftmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA MI spec and compare outputs.

        Checks that export keeps ``aten.log_softmax.int`` and contains no
        quantized_decomposed ops. After partitioning, it checks that no edge
        log_softmax node remains and exactly one delegate call exists. It then
        runs the program and compares outputs via
        ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.log_softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not([_EDGE_LOG_SOFTMAX_OP])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_logsoftmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA BI spec, then compare outputs.

        After quantize+export, ``aten.log_softmax.int`` must be gone. The graph
        must contain quantized_decomposed ops and ``aten.mul.Tensor``, i.e.
        log_softmax is expected to have been decomposed. The partition checks
        mirror the MI pipeline.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not([_EDGE_LOG_SOFTMAX_OP])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # qtol is ArmTester's tolerance for quantized output comparison;
            # see ArmTester for its exact semantics.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_logsoftmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U target given by ``compile_spec``.

        Performs the same graph checks as the BI pipeline but stops at
        ``to_executorch``. No program execution or output comparison happens
        here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.log_softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not([_EDGE_LOG_SOFTMAX_OP])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_logsoftmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos BI pipeline with the Ethos-U55 compile spec."""
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_logsoftmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos BI pipeline with the Ethos-U85 compile spec."""
        self._test_logsoftmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_MI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_BI_pipeline(self.LogSoftmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_logsoftmax_tosa_u55_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_logsoftmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        # Must target the U85 spec; routing through the U55 helper would leave
        # the U85 lowering path untested.
        self._test_logsoftmax_tosa_u85_BI_pipeline(
            self.LogSoftmax(dim=dim), (test_data,)
        )