# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestBMM(unittest.TestCase):
    """Check that ``torch.bmm`` is lowered into a single delegate call and runs."""

    class BMM(torch.nn.Module):
        """Minimal module whose forward is a single ``torch.bmm`` call."""

        def forward(self, x, y):
            return torch.bmm(x, y)

    def setUp(self):
        # Inputs are drawn with torch.randn; fixing the seed makes any
        # numerical mismatch (notably in the fp16 case) reproducible.
        torch.manual_seed(0)

    def _test_bmm(self, inputs):
        """Export, lower, serialize and execute ``BMM`` on ``inputs``.

        Asserts that export yields exactly one ``aten.bmm`` node, that
        ``to_edge_transform_and_lower`` produces exactly one
        ``executorch_call_delegate`` and leaves no edge-dialect bmm op behind,
        and finally runs the serialized program and compares its outputs via
        ``Tester.run_method_and_compare_outputs``.

        Args:
            inputs: Tuple ``(x, y)`` forwarded to ``BMM.forward``.
        """
        (
            Tester(self.BMM(), inputs)
            .export()
            .check_count({"torch.ops.aten.bmm.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # After lowering, no bmm should remain outside the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_bmm(self):
        """Batched matmul on float16 inputs: (2, 3, 4) @ (2, 4, 6)."""
        inputs = (
            torch.randn(2, 3, 4).to(torch.float16),
            torch.randn(2, 4, 6).to(torch.float16),
        )
        self._test_bmm(inputs)

    def test_fp32_bmm(self):
        """Batched matmul on float32 inputs: (2, 3, 4) @ (2, 4, 6)."""
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 4, 6),
        )
        self._test_bmm(inputs)