# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMaximum(unittest.TestCase):
    class Maximum(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.maximum(x, y)

    def _test_maximum(self, inputs):
        (
            Tester(self.Maximum(), inputs)
            .export()
            .check_count({"torch.ops.aten.maximum.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maximum(self):
        inputs = (
            torch.randn(2, 3, 4).to(torch.float16),
            torch.randn(2, 3, 4).to(torch.float16),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum(self):
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum_broadcast(self):
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 1, 4),
        )
        (
            Tester(self.Maximum(), inputs)
            .export()
            .check_count({"torch.ops.aten.maximum.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )
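

# Optional entry point (an added convenience, not required by the repo's test
# runner): lets this module be executed directly, e.g. `python test_maximum.py`,
# using the standard unittest CLI.
if __name__ == "__main__":
    unittest.main()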