# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Optional, Tuple, Union

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# A single torch.div operand: either a tensor or a Python scalar.
_DivOperand = Union[torch.Tensor, torch.types.Number]

# Each entry is (test_name, input, other, rounding_mode); see torch.div() for
# the meaning of rounding_mode.
# NOTE(review): the random tensors below are drawn once at import time without
# a fixed seed, so values differ between runs.
test_data_suite = [
    (
        "op_div_rank1_ones",
        torch.ones(5),
        torch.ones(5),
        None,
    ),
    (
        "op_div_rank1_rand",
        torch.rand(5) * 5,
        torch.rand(5) * 5,
        None,
    ),
    (
        "op_div_rank1_negative_ones",
        torch.ones(5) * (-1),
        torch.ones(5) * (-1),
        None,
    ),
    (
        "op_div_rank4_ones",
        torch.ones(5, 10, 25, 20),
        torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_negative_ones",
        (-1) * torch.ones(5, 10, 25, 20),
        torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_ones_div_negative",
        torch.ones(5, 10, 25, 20),
        (-1) * torch.ones(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_large_rand",
        200 * torch.rand(5, 10, 25, 20),
        torch.rand(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20),
        torch.rand(5, 10, 25, 20),
        None,
    ),
    (
        "op_div_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20) + 1,
        torch.rand(5, 10, 25, 20) + 1,
        None,
    ),
]


class TestDiv(unittest.TestCase):
    """Tests division (torch.div) lowering through the Arm backend.

    Each case in ``test_data_suite`` is run through three flows: TOSA MI
    (float), TOSA BI (quantized), and the Ethos-U55 compile spec (quantized).
    """

    class Div(torch.nn.Module):
        """Minimal module whose forward is a single torch.div call."""

        def forward(
            self,
            input_: Union[torch.Tensor, torch.types.Number],
            other_: Union[torch.Tensor, torch.types.Number],
            rounding_mode: Optional[str] = None,
        ):
            # Only pass rounding_mode when it is set; the MI pipeline below
            # expects exactly one aten.div.Tensor node for the default case.
            if rounding_mode is None:
                return torch.div(input=input_, other=other_)
            else:
                return torch.div(
                    input=input_, other=other_, rounding_mode=rounding_mode
                )

    def _test_div_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[_DivOperand, ...]
    ):
        """Export ``module`` unquantized for TOSA MI, lower it, and compare outputs.

        Args:
            module: Module under test.
            test_data: Positional example inputs, e.g. ``(input_, other_)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.div.Tensor": 1})
            # The MI flow is float-only; no quantize/dequantize ops may appear.
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The whole graph must end up in a single delegate.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_div_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[_DivOperand, ...]
    ):
        """Quantize and export ``module`` for TOSA BI, lower it, and compare outputs.

        Args:
            module: Module under test.
            test_data: Positional example inputs, e.g. ``(input_, other_)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            # After quantization the graph is expected to express div as
            # reciprocal followed by mul (asserted here).
            .check_count(
                {"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # Loose tolerances, presumably to absorb quantization error.
            .run_method_and_compare_outputs(inputs=test_data, atol=1, rtol=0.1)
        )

    def _test_div_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[_DivOperand, ...]
    ):
        """Quantize and lower ``module`` with the Ethos-U55 compile spec.

        This flow stops after ``to_executorch()``; outputs are not compared here.

        Args:
            module: Module under test.
            test_data: Positional example inputs, e.g. ``(input_, other_)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count(
                {"torch.ops.aten.reciprocal.default": 1, "torch.ops.aten.mul.Tensor": 1}
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # TODO(review): the rounding_mode column of test_data_suite is accepted by
    # the tests below but never forwarded to Div. Every current case uses None,
    # so nothing is lost today. A non-None mode would also need the MI check to
    # expect a different aten overload than div.Tensor.

    @parameterized.expand(test_data_suite)
    def test_div_tosa_MI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Run a suite case through the TOSA MI (float) pipeline."""
        test_data = (input_, other_)
        self._test_div_tosa_MI_pipeline(self.Div(), test_data)

    @parameterized.expand(test_data_suite)
    def test_div_tosa_BI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Run a suite case through the TOSA BI (quantized) pipeline."""
        test_data = (input_, other_)
        self._test_div_tosa_BI_pipeline(self.Div(), test_data)

    @parameterized.expand(test_data_suite)
    def test_div_u55_BI(
        self,
        test_name: str,
        input_: Union[torch.Tensor, torch.types.Number],
        other_: Union[torch.Tensor, torch.types.Number],
        rounding_mode: Optional[str] = None,
    ):
        """Run a suite case through the Ethos-U55 (quantized) lowering pipeline."""
        test_data = (input_, other_)
        self._test_div_u55_BI_pipeline(self.Div(), test_data)