# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestHardswish(unittest.TestCase):
    """Tests that hardswish is lowered into a single XNNPACK delegate call.

    Both the module form (``torch.nn.Hardswish``) and the functional form
    (``torch.nn.functional.hardswish``) are exercised.
    """

    class Hardswish(torch.nn.Module):
        """Wraps ``torch.nn.Hardswish`` as a submodule."""

        def __init__(self):
            super().__init__()
            self.hardswish = torch.nn.Hardswish()

        def forward(self, x):
            return self.hardswish(x)

    class HardswishFunctional(torch.nn.Module):
        """Calls ``torch.nn.functional.hardswish`` directly."""

        def forward(self, x):
            return torch.nn.functional.hardswish(x)

    def _test_hardswish(self, inputs, module=None):
        """Run the full export/lower/serialize/compare pipeline on a module.

        Args:
            inputs: Tuple of example input tensors passed to ``Tester``.
            module: The ``torch.nn.Module`` under test. Defaults to a fresh
                ``self.Hardswish()`` so existing callers are unaffected.

        The pipeline asserts the following:

        * Export yields exactly one ``aten.hardswish.default`` node.
        * After ``to_edge_transform_and_lower`` there is exactly one
          ``executorch_call_delegate``.
        * No edge-dialect hardswish op remains undelegated.

        It then serializes the program and compares outputs using
        ``Tester.run_method_and_compare_outputs``.
        """
        if module is None:
            module = self.Hardswish()
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.hardswish.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # The op must have been absorbed by the delegate rather than left
            # in the graph as a standalone edge op.
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_hardswish_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_hardswish(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish_functional(self):
        inputs = (torch.randn(1, 3, 3),)
        # Reuse the shared pipeline instead of duplicating it for the
        # functional variant.
        self._test_hardswish(inputs, module=self.HardswishFunctional())