# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestHardTanh(unittest.TestCase):
    """Checks that ``aten.hardtanh`` is lowered to the XNNPACK delegate.

    Each test exports a small module, lowers it, asserts that no standalone
    hardtanh op survives outside the delegate, then runs the serialized
    program and compares its outputs against eager mode.
    """

    class HardTanh(torch.nn.Module):
        """Computes ``hardtanh(x + x, min_val, max_val)``.

        The ``x + x`` puts an extra op ahead of hardtanh. In the quantized
        test, that add contributes one of the expected quantize nodes.
        """

        def __init__(self, min_val=-1.0, max_val=1.0):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val
            # Build the activation once and register it as a submodule,
            # rather than constructing a new nn.Hardtanh on every forward.
            self.hardtanh = torch.nn.Hardtanh(min_val, max_val)

        def forward(self, x):
            y = x + x
            return self.hardtanh(y)

    def _run_fp32_hardtanh(self, min_val=-1.0, max_val=1.0):
        """Run the fp32 export -> lower -> execute pipeline over several shapes.

        Args:
            min_val: Lower clamp bound passed to ``HardTanh``.
            max_val: Upper clamp bound passed to ``HardTanh``.
        """
        inputs_sets = [torch.randn(2, 3, 4), torch.randn(7, 5, 2), torch.randn(2, 9)]
        for sample in inputs_sets:
            # subTest makes a failure report which input shape caused it.
            with self.subTest(shape=tuple(sample.shape)):
                (
                    Tester(self.HardTanh(min_val, max_val), (sample,))
                    .export()
                    .check_count({"torch.ops.aten.hardtanh.default": 1})
                    .to_edge_transform_and_lower()
                    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
                    .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )

    def test_fp32_hardtanh(self):
        """Default bounds ``[-1.0, 1.0]``."""
        self._run_fp32_hardtanh()

    def test_fp32_hardtanh_bound(self):
        """Custom bounds ``[-2.0, 2.0]``."""
        self._run_fp32_hardtanh(-2.0, 2.0)

    def test_qs8_hardtanh(self):
        """Quantized (qs8) hardtanh is fully absorbed into the delegate."""
        inputs_sets = [torch.randn(2, 3, 2), torch.randn(2, 1, 2), torch.randn(2, 3)]
        for sample in inputs_sets:
            with self.subTest(shape=tuple(sample.shape)):
                (
                    Tester(self.HardTanh(), (sample,))
                    .quantize()
                    .export()
                    .check_node_count(
                        {
                            # Expect three quantize ops - one for input, hardtanh, and add.
                            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                            torch.ops.aten.hardtanh.default: 1,
                        }
                    )
                    .to_edge_transform_and_lower()
                    .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
                    .check_not(
                        [
                            "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
                            "torch.ops.quantized_decomposed",
                        ]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )