# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/bmm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
13class TestBMM(unittest.TestCase):
14    class BMM(torch.nn.Module):
15        def __init__(self):
16            super().__init__()
17
18        def forward(self, x, y):
19            return torch.bmm(x, y)
20
21    def _test_bmm(self, inputs):
22        (
23            Tester(self.BMM(), inputs)
24            .export()
25            .check_count({"torch.ops.aten.bmm.default": 1})
26            .to_edge_transform_and_lower()
27            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
28            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
29            .to_executorch()
30            .serialize()
31            .run_method_and_compare_outputs()
32        )
33
34    def test_fp16_bmm(self):
35        inputs = (
36            torch.randn(2, 3, 4).to(torch.float16),
37            torch.randn(2, 4, 6).to(torch.float16),
38        )
39        self._test_bmm(inputs)
40
41    def test_fp32_bmm(self):
42        inputs = (
43            torch.randn(2, 3, 4),
44            torch.randn(2, 4, 6),
45        )
46        self._test_bmm(inputs)
47