# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(1, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    ("rand", torch.rand(10, 10) - 0.5),
    ("randn_pos", torch.randn(10) + 10),
    ("randn_neg", torch.randn(10) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]


class TestRelu(unittest.TestCase):
    """Tests lowering of ``torch.nn.ReLU`` through the Arm backend.

    Each case in ``test_data_suite`` is exercised for:
      * TOSA MI: unquantized export, lowered and then run via
        ``run_method_and_compare_outputs``.
      * TOSA BI: quantized export, lowered and then run via
        ``run_method_and_compare_outputs``.
      * Ethos-U55 / Ethos-U85: quantized export, lowered only (not run).
    """

    class Relu(torch.nn.Module):
        """Minimal module wrapping a single ``torch.nn.ReLU``."""

        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    def _test_relu_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Export ``module`` unquantized, lower it for TOSA MI and compare outputs.

        Asserts that the exported graph has an aten ReLU and no
        quantize/dequantize ops, and that after partitioning no edge-dialect
        ReLU remains and exactly one delegate call exists.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.relu.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _quantize_and_lower_relu(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize, export, partition and lower ``module`` for ``compile_spec``.

        Shared by the TOSA BI and Ethos-U BI pipelines. Asserts that exactly
        one aten ReLU is present after quantized export (alongside
        ``quantized_decomposed`` ops), that no edge-dialect ReLU remains after
        partitioning, and that exactly one delegate call exists.

        Returns:
            The tester object returned by ``to_executorch()``, so callers can
            chain further stages (e.g. ``run_method_and_compare_outputs``).
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.relu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_relu_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` for TOSA BI, then run and compare outputs."""
        tester = self._quantize_and_lower_relu(
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"), module, test_data
        )
        tester.run_method_and_compare_outputs(inputs=test_data)

    def _test_relu_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U ``compile_spec``.

        Only lowering is checked; the resulting program is not executed.
        """
        # NOTE(review): presumably not run because execution needs an Ethos-U
        # target or simulator — confirm before adding a run stage.
        self._quantize_and_lower_relu(compile_spec, module, test_data)

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_MI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_tosa_MI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_tosa_BI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Relu(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_relu_u85_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Relu(), (test_data,)
        )