# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestMaxDim(unittest.TestCase):
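    # Returns both the max values and the argmax indices along dims 2 and 3
    # (keepdim=True), so the exported graph contains two aten.max.dim nodes
    # whose index outputs are consumed.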
    class Max(torch.nn.Module):
        def forward(self, x):
            max_values_1, max_indices_1 = torch.max(x, dim=2, keepdim=True)
            max_values_2, max_indices_2 = torch.max(x, dim=3, keepdim=True)
            return (max_values_1, max_indices_1, max_values_2, max_indices_2)

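    # Same reductions as Max, but the index outputs are discarded so only the
    # max values are returned from the module.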
    class MaxNoIndices(torch.nn.Module):
        def forward(self, x):
            max_values_1, _ = torch.max(x, dim=2, keepdim=True)
            max_values_2, _ = torch.max(x, dim=3, keepdim=True)
            return (max_values_1, max_values_2)

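    # When the indices are consumed, both max.dim nodes are expected to stay
    # undelegated (no executorch_call_delegate in the lowered graph),
    # presumably because the XNNPACK delegate cannot produce the index output.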
    def _test_max_dim(self, inputs):
        (
            Tester(self.Max(), inputs)
            .export()
            .check_count({"torch.ops.aten.max.dim": 2})
            .to_edge_transform_and_lower()
            .check_not(["torch.ops.higher_order.executorch_call_delegate"])
            .check_count({"executorch_exir_dialects_edge__ops_aten_max_dim": 2})
        )

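    # With the indices discarded, both max.dim nodes are expected to be lowered
    # into a single XNNPACK delegate; the serialized program is then run and
    # its outputs compared against eager mode.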
    def _test_max_dim_no_indices(self, inputs):
        (
            Tester(self.MaxNoIndices(), inputs)
            .export()
            .check_count({"torch.ops.aten.max.dim": 2})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_max_dim"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_max_dim_with_indices(self):
        inputs = (torch.randn(16, 3, 12, 12).to(torch.float16),)
        self._test_max_dim(inputs)

    def test_fp32_max_dim_with_indices(self):
        inputs = (torch.randn(16, 3, 12, 12),)
        self._test_max_dim(inputs)

    def test_fp32_max_dim_no_indices(self):
        inputs = (torch.randn(16, 3, 12, 12),)
        self._test_max_dim_no_indices(inputs)

    def test_fp16_max_dim_no_indices(self):
        inputs = (torch.randn(16, 3, 12, 12).to(torch.float16),)
        self._test_max_dim_no_indices(inputs)