xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/abs.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestAbs(unittest.TestCase):
    """Export/lowering/runtime tests for ``torch.abs`` through the XNNPACK tester.

    Each case exports a single-op module and lowers it. Lowering uses either
    ``to_edge_transform_and_lower()`` or the legacy ``to_edge().partition()``
    path. The case then checks that exactly one delegate call remains and that
    the edge-dialect abs op is gone. Finally it serializes and runs the program
    and compares its outputs.
    """

    class Abs(torch.nn.Module):
        """Minimal module whose forward is a single ``torch.abs`` call."""

        def forward(self, x):
            return torch.abs(x)

    @staticmethod
    def _make_inputs(dtype=None):
        """Return the shared 2x4 sample input as a one-element tuple.

        The values mix zero, small fractions, negatives and large magnitudes
        so both signs are exercised. When ``dtype`` is given, the float tensor
        is cast with ``.to(dtype)``. Building via ``torch.Tensor`` first and
        then casting keeps the same rounding path as before.
        """
        sample = torch.Tensor(
            [
                [0.0, 0.1, 0.5, 0.499],
                [-0.6, -0.4, 100.1, -1000.1],
            ],
        )
        if dtype is not None:
            sample = sample.to(dtype)
        return (sample,)

    def _test_abs(self, inputs, legacy_mode: bool = False):
        """Export, lower, serialize and run ``Abs`` on ``inputs``.

        Args:
            inputs: Tuple of example input tensors fed to the module.
            legacy_mode: When True, lower with ``to_edge().partition()``;
                otherwise use ``to_edge_transform_and_lower()``.
        """
        tester = Tester(self.Abs(), inputs)
        # The exported graph must contain exactly one aten abs op before lowering.
        tester = tester.export().check_count({"torch.ops.aten.abs.default": 1})

        lowered = (
            tester.to_edge().partition()
            if legacy_mode
            else tester.to_edge_transform_and_lower()
        )

        # After lowering, abs should live inside one delegate call and no longer
        # appear as a standalone edge op.
        lowered.check_count(
            {"torch.ops.higher_order.executorch_call_delegate": 1}
        ).check_not(
            ["executorch_exir_dialects_edge__ops_aten_abs_default"]
        ).to_executorch().serialize().run_method_and_compare_outputs()

    def test_fp16_abs(self):
        self._test_abs(self._make_inputs(torch.float16), legacy_mode=False)

    def test_fp16_abs_legacy_mode(self):
        self._test_abs(self._make_inputs(torch.float16), legacy_mode=True)

    def test_fp32_abs(self):
        self._test_abs(self._make_inputs(), legacy_mode=False)

    def test_fp32_abs_legacy_mode(self):
        self._test_abs(self._make_inputs(), legacy_mode=True)
84        self._test_abs(inputs, legacy_mode=True)
85