# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/multiply.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMul(unittest.TestCase):
    """Tests lowering of elementwise multiplication to the XNNPACK delegate.

    Every test drives a small module through the ``Tester`` pipeline. It
    exports the module and checks the expected ATen ops, then lowers with
    ``to_edge_transform_and_lower``. After lowering it checks that exactly one
    ``executorch_call_delegate`` is present and that the listed edge-dialect
    ops are absent. Finally it serializes the program and calls
    ``run_method_and_compare_outputs``.
    """

    class Mul(torch.nn.Module):
        """Multiply two tensors; the tests feed it broadcasting shapes."""

        def forward(self, x, y):
            return x * y

    class Mul2(torch.nn.Module):
        """Square a tensor (``x * x``): both operands are the same input."""

        def forward(self, x):
            return x * x

    class MulFunctional(torch.nn.Module):
        """Multiply via the functional API; exports to three ``aten.mul`` nodes."""

        def forward(self, x, y):
            # Two explicit mul calls plus the outer ``*`` give three mul nodes.
            # The second call reaches ``torch.mul`` through the ``torch``
            # attribute of the ``torch.functional`` module, exercising a
            # different spelling of the same op.
            return torch.mul(x, y) * torch.functional.torch.mul(x, y)

    class MulRelu(torch.nn.Module):
        """Multiply followed by ReLU; both ops are expected inside the delegate."""

        def forward(self, x, y):
            return torch.nn.functional.relu(x * y)

    def _test_mul(self, inputs):
        """Lower the float ``Mul`` module with *inputs* and check delegation.

        Args:
            inputs: ``(x, y)`` tensors passed to ``Mul.forward``.
        """
        (
            Tester(self.Mul(), inputs)
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_mul_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def _test_qs8_mul(self, module, inputs, export_op_counts, lowered_edge_ops):
        """Quantize *module*, lower it to XNNPACK, and check full delegation.

        Args:
            module: the ``torch.nn.Module`` under test.
            inputs: example inputs forwarded to ``Tester``.
            export_op_counts: ``{op_name: count}`` expected in the exported,
                pre-lowering graph.
            lowered_edge_ops: edge-dialect op names that must no longer
                appear once the graph has been lowered.
        """
        (
            Tester(module, inputs)
            .quantize()
            .export()
            .check_count(export_op_counts)
            .check(["torch.ops.quantized_decomposed"])
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # NOTE(review): after lowering, this file refers to aten ops by
            # their edge-dialect names (executorch_exir_dialects_edge__ops_...).
            # If quantized_decomposed ops are renamed the same way, the
            # "torch.ops.quantized_decomposed" substring below may never
            # match at this stage. Confirm this check is effective.
            .check_not([*lowered_edge_ops, "torch.ops.quantized_decomposed"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_mul(self):
        # (1, 3) broadcasts against (4, 3).
        inputs = (
            torch.randn((1, 3)).to(torch.float16),
            torch.randn((4, 3)).to(torch.float16),
        )
        self._test_mul(inputs)

    def test_fp32_mul(self):
        inputs = (torch.randn((1, 3)), torch.randn((4, 3)))
        self._test_mul(inputs)

    def test_qs8_mul(self):
        # (1, 1, 4, 1) broadcasts against (1, 1, 4, 4) along the last dim.
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1))
        self._test_qs8_mul(
            self.Mul(),
            inputs,
            {"torch.ops.aten.mul.Tensor": 1},
            ["executorch_exir_dialects_edge__ops_aten_mul_Tensor"],
        )

    def test_qs8_mul2(self):
        inputs = (torch.randn(1, 1, 4, 4),)
        self._test_qs8_mul(
            self.Mul2(),
            inputs,
            {"torch.ops.aten.mul.Tensor": 1},
            ["executorch_exir_dialects_edge__ops_aten_mul_Tensor"],
        )

    def test_qs8_mul_functional(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        self._test_qs8_mul(
            self.MulFunctional(),
            inputs,
            {"torch.ops.aten.mul.Tensor": 3},
            ["executorch_exir_dialects_edge__ops_aten_mul_Tensor"],
        )

    def test_qs8_mul_relu(self):
        inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
        self._test_qs8_mul(
            self.MulRelu(),
            inputs,
            {
                "torch.ops.aten.mul.Tensor": 1,
                "torch.ops.aten.relu.default": 1,
            },
            [
                "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
                "executorch_exir_dialects_edge__ops_aten_relu_default",
            ],
        )