xref: /aosp_15_r20/external/pytorch/test/export/test_functionalized_assertions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2import torch
3from torch.testing._internal.common_utils import run_tests, TestCase
4
5
6class TestFuntionalAssertions(TestCase):
7    def test_functional_assert_async_msg(self) -> None:
8        dep_token = torch.ops.aten._make_dep_token()
9        self.assertEqual(
10            torch.ops.aten._functional_assert_async.msg(
11                torch.tensor(1), "test msg", dep_token
12            ),
13            dep_token,
14        )
15        with self.assertRaisesRegex(RuntimeError, "test msg"):
16            torch.ops.aten._functional_assert_async.msg(
17                torch.tensor(0), "test msg", dep_token
18            ),
19
20    def test_functional_sym_constrain_range(self) -> None:
21        dep_token = torch.ops.aten._make_dep_token()
22        self.assertEqual(
23            torch.ops.aten._functional_sym_constrain_range(
24                3, min=2, max=5, dep_token=dep_token
25            ),
26            dep_token,
27        )
28
29
# Hand control to run_tests() only when this module is executed as a script.
if __name__ == "__main__":
    run_tests()
32