# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data, dim)
    ("zeros", torch.zeros(10, 8, 5, 2), 0),
    ("zeros_neg_dim", torch.zeros(10, 7, 8, 9), -4),
    ("ones", torch.ones(10, 10), 1),
    ("ones_neg_dim", torch.ones(10, 3, 4), -1),
    ("rand", torch.rand(1, 2, 5, 8), 2),
    ("rand_neg_dim", torch.rand(2, 10, 8, 10), -2),
    ("randn", torch.randn(10, 10, 10, 10), 3),
    ("randn_neg_dim", torch.randn(10, 5, 8, 7), -3),
]


class TestSoftmax(unittest.TestCase):
    """Tests lowering of ``torch.nn.Softmax`` through the Arm backend.

    Every case in ``test_data_suite`` is run through four pipelines:
    TOSA-0.80.0+MI, TOSA-0.80.0+BI, Ethos-U55 BI and Ethos-U85 BI.
    """

    class Softmax(torch.nn.Module):
        """Minimal module that applies ``torch.nn.Softmax`` along ``dim``."""

        def __init__(self, dim: int = -1):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.softmax(x)

    def _test_softmax_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Export, partition and run ``module`` with the TOSA MI compile spec.

        Checks that the exported graph still contains ``aten.softmax.int``,
        has no quantize/dequantize ops, and is fully delegated to a single
        call. Finally runs the program and compares it with eager outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.softmax.int"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_softmax_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Quantize, export, partition and run ``module`` with the TOSA BI spec.

        After quantization the exported graph is expected to contain no
        ``aten.softmax.int``. It should instead contain quantize/dequantize
        ops and ``aten.mul.Tensor``. The result must be fully delegated to a
        single call and its outputs must match eager execution.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_softmax_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ) -> None:
        """Quantize, export and partition ``module`` for an Ethos-U target.

        Applies the same graph checks as the TOSA BI pipeline, but stops
        after ``to_executorch()`` without comparing outputs. NOTE(review):
        presumably this is because running requires the target
        hardware/simulator; confirm.

        Args:
            compile_spec: Ethos-U compile spec (e.g. from
                ``common.get_u55_compile_spec()``).
            module: Module under test.
            test_data: Example inputs used for quantization and export.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_not(["torch.ops.aten.softmax.int"])
            .check(["torch.ops.quantized_decomposed", "torch.ops.aten.mul.Tensor"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten__softmax_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_softmax_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the Ethos-U BI pipeline with the Ethos-U55 compile spec."""
        self._test_softmax_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_softmax_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the Ethos-U BI pipeline with the Ethos-U85 compile spec."""
        self._test_softmax_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_MI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_BI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_u55_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_u55_BI_pipeline(self.Softmax(dim=dim), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_softmax_tosa_u85_BI(
        self,
        test_name: str,
        test_data: torch.Tensor,
        dim: int,
    ):
        self._test_softmax_tosa_u85_BI_pipeline(self.Softmax(dim=dim), (test_data,))