# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMaxDim(unittest.TestCase):
    """Lowering tests for ``torch.max(x, dim=..., keepdim=True)`` via XNNPACK.

    ``Max`` returns both values and indices and is expected to stay entirely
    undelegated. ``MaxNoIndices`` discards the indices and is expected to be
    lowered into a single XNNPACK delegate call.
    """

    class Max(torch.nn.Module):
        """Reduce over dims 2 and 3, returning values and indices of both."""

        def forward(self, x):
            max_values_1, max_indices_1 = torch.max(x, dim=2, keepdim=True)
            max_values_2, max_indices_2 = torch.max(x, dim=3, keepdim=True)
            return (max_values_1, max_indices_1, max_values_2, max_indices_2)

    class MaxNoIndices(torch.nn.Module):
        """Reduce over dims 2 and 3, returning only the max values."""

        def forward(self, x):
            max_values_1, _ = torch.max(x, dim=2, keepdim=True)
            max_values_2, _ = torch.max(x, dim=3, keepdim=True)
            return (max_values_1, max_values_2)

    @staticmethod
    def _make_inputs(dtype=torch.float32):
        """Return a 1-tuple holding a random ``(16, 3, 12, 12)`` tensor of *dtype*.

        The tensor is sampled with ``torch.randn`` and then cast, matching how
        the individual tests originally built their inputs.
        """
        return (torch.randn(16, 3, 12, 12).to(dtype),)

    def _test_max_dim(self, inputs):
        """Check that ``max.dim`` with used indices is left undelegated.

        After export the graph must contain two ``aten.max.dim`` calls. After
        ``to_edge_transform_and_lower`` no ``executorch_call_delegate`` may
        exist, and both edge ``max.dim`` ops must remain. The program is not
        executed by this helper.
        """
        (
            Tester(self.Max(), inputs)
            .export()
            .check_count({"torch.ops.aten.max.dim": 2})
            .to_edge_transform_and_lower()
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
            .check_count({"executorch_exir_dialects_edge__ops_aten_max_dim": 2})
        )

    def _test_max_dim_no_indices(self, inputs):
        """Check that ``max.dim`` without used indices is fully delegated.

        Both ``max.dim`` calls must be absorbed into exactly one delegate
        call. The lowered program is then serialized, run, and its outputs
        compared against eager mode.
        """
        (
            Tester(self.MaxNoIndices(), inputs)
            .export()
            .check_count({"torch.ops.aten.max.dim": 2})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    # Backward-compatible alias for the helper's original (misspelled) name.
    _test_max_dim_no_indicies = _test_max_dim_no_indices

    def test_fp16_max_dim_with_indices(self):
        self._test_max_dim(self._make_inputs(torch.float16))

    def test_fp32_max_dim_with_indices(self):
        self._test_max_dim(self._make_inputs(torch.float32))

    def test_fp32_max_dim_no_indices(self):
        self._test_max_dim_no_indices(self._make_inputs(torch.float32))

    def test_fp16_max_dim_no_indices(self):
        self._test_max_dim_no_indices(self._make_inputs(torch.float16))