# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMul(unittest.TestCase):
    """Tests that aten.mul.Tensor (alone and fused with relu) is fully
    lowered to the XNNPACK delegate for fp16, fp32, and qs8 inputs."""

    class Mul(torch.nn.Module):
        """Element-wise multiply of two tensors."""

        def forward(self, x, y):
            z = x * y
            return z

    class Mul2(torch.nn.Module):
        """Element-wise square: multiplies an input by itself."""

        def forward(self, x):
            z = x * x
            return z

    class MulFunctional(torch.nn.Module):
        """Produces three mul ops via the functional torch.mul spellings."""

        def forward(self, x, y):
            z = torch.mul(x, y) * torch.functional.torch.mul(x, y)
            return z

    class MulRelu(torch.nn.Module):
        """Multiply followed by relu (exercises mul+relu handling)."""

        def forward(self, x, y):
            z = x * y
            return torch.nn.functional.relu(z)

    def _test_mul(self, inputs):
        # Shared float pipeline: export, lower to the XNNPACK delegate,
        # verify the edge-dialect mul op was consumed by the delegate,
        # then run the serialized program and compare against eager.
        (
            Tester(self.Mul(), inputs)
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_mul(self):
        # Broadcasting mul ((1, 3) * (4, 3)) in half precision.
        inputs = (
            torch.randn((1, 3)).to(torch.float16),
            torch.randn((4, 3)).to(torch.float16),
        )
        self._test_mul(inputs)

    def test_fp32_mul(self):
        # Broadcasting mul ((1, 3) * (4, 3)) in single precision.
        inputs = (torch.randn((1, 3)), torch.randn((4, 3)))
        self._test_mul(inputs)

    def test_qs8_mul(self):
        # Quantized (qs8) broadcasting mul of two distinct inputs. After
        # lowering, neither the edge mul op nor any quantize/dequantize
        # ops may remain outside the delegate.
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
        (
            Tester(self.Mul(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_mul2(self):
        # Quantized mul where both operands are the same tensor (x * x).
        inputs = (torch.randn(1, 1, 4, 4),)
        (
            Tester(self.Mul2(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_mul_functional(self):
        # Quantized graph with three mul ops (two functional spellings
        # plus the combining multiply); all must land in one delegate.
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        (
            Tester(self.MulFunctional(), inputs)
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 3})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_mul_relu(self):
        # Quantized mul followed by relu: after lowering, both the mul
        # and the relu edge ops must be absorbed by the delegate.
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        (
            Tester(self.MulRelu(), inputs)
            .quantize()
            .export()
            .check_count(
                {
                    "torch.ops.aten.mul.Tensor": 1,
                    "torch.ops.aten.relu.default": 1,
                }
            )
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
                    "executorch_exir_dialects_edge__ops_aten_relu_default",
                    "torch.ops.quantized_decomposed",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )