# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMaximum(unittest.TestCase):
    """Checks that torch.maximum is lowered to and executed by the XNNPACK delegate."""

    class Maximum(torch.nn.Module):
        def forward(self, x, y):
            return torch.maximum(x, y)

    def _test_maximum(self, inputs):
        # Export the module, verify the aten op is captured, lower to the
        # XNNPACK delegate, confirm the edge-dialect op was fully consumed,
        # then serialize, run, and compare against eager-mode outputs.
        (
            Tester(self.Maximum(), inputs)
            .export()
            .check_count({"torch.ops.aten.maximum.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maximum(self):
        inputs = (
            torch.randn(2, 3, 4).to(torch.float16),
            torch.randn(2, 3, 4).to(torch.float16),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum(self):
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum_broadcast(self):
        # Second input broadcasts along dim 1 (size 1 vs. 3).
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 1, 4),
        )
        self._test_maximum(inputs)
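
# A minimal usage sketch, assuming the standard unittest CLI: this guard is an
# assumption (not present in the original file) and the repo's preferred runner
# may differ (e.g. pytest or buck), but it lets the file be executed directly
# with `python test_maximum.py`.
if __name__ == "__main__":
    unittest.main()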