# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized

# Each entry is expanded by ``parameterized.expand`` into one test case per
# pipeline below.
test_data_suite = [
    # (test_name, test_data)
    ("ones_rank4", torch.ones(1, 10, 10, 10)),
    ("ones_rank3", torch.ones(10, 10, 10)),
    # Offset keeps every element strictly positive.
    ("rand", torch.rand(10, 10) + 0.001),
    ("randn_pos", torch.randn(10) + 10),
    # NOTE(review): the element-wise max with 0.0 clamps negative samples to
    # exactly zero, so this input may contain zeros (log(0) == -inf).
    # Confirm that this is intended.
    ("randn_spread", torch.max(torch.Tensor([0.0]), torch.randn(10) * 100)),
    ("ramp", torch.arange(0.01, 20, 0.2)),
]


class TestLog(unittest.TestCase):
    """Tests lowering of aten.log"""

    class Log(torch.nn.Module):
        """Minimal module whose forward pass is a single ``torch.log``."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.log(x)

    def _test_log_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the non-quantized "TOSA-0.80.0+MI" flow.

        Checks that the exported graph contains ``aten.log`` and no
        ``quantized_decomposed`` ops, that after partitioning the edge
        ``aten.log`` op is gone and exactly one delegate call remains, and
        finally runs the program and compares outputs.

        Args:
            module: Module under test.
            test_data: Example inputs, passed both to ``ArmTester`` and to
                ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.log.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_log_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the quantized "TOSA-0.80.0+BI" flow.

        Same stages as the MI pipeline, except that the module is quantized
        before export and the exported graph is expected to contain
        ``quantized_decomposed`` ops.

        Args:
            module: Module under test.
            test_data: Example inputs, passed both to ``ArmTester`` and to
                ``run_method_and_compare_outputs``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.log.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_log_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Lower a quantized ``module`` with the given Ethos-U compile spec.

        Unlike the TOSA pipelines, this one requires exactly one ``aten.log``
        in the exported graph and stops after ``to_executorch()``; outputs
        are not executed or compared here.

        Args:
            compile_spec: Compile spec forwarded unchanged to ``ArmTester``.
            module: Module under test.
            test_data: Example inputs passed to ``ArmTester``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.log.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_log_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_log_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        """Non-quantized TOSA lowering for every entry in the data suite."""
        self._test_log_tosa_MI_pipeline(self.Log(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_log_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        """Quantized TOSA lowering for every entry in the data suite."""
        self._test_log_tosa_BI_pipeline(self.Log(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_log_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        """Quantized lowering using the compile spec from get_u55_compile_spec."""
        self._test_log_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Log(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_log_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
        """Quantized lowering using the compile spec from get_u85_compile_spec."""
        self._test_log_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Log(), (test_data,)
        )