xref: /aosp_15_r20/external/executorch/exir/tests/test_warnings.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8import unittest
9import warnings
10from typing import Any, Callable, Optional
11
12from executorch.exir._warnings import deprecated, experimental, ExperimentalWarning
13
14#
15# Classes
16#
17
18
class UndecoratedClass:
    """Empty class carrying no lifecycle decorator."""
21
22
@deprecated("DeprecatedClass message")
class DeprecatedClass:
    """Empty class wrapped with ``@deprecated``."""
26
27
@experimental("ExperimentalClass message")
class ExperimentalClass:
    """Empty class wrapped with ``@experimental``."""
31
32
33#
34# Functions
35#
36
37
def undecorated_function() -> None:
    """Do nothing; this function carries no lifecycle decorator."""
    return None
40
41
@deprecated("deprecated_function message")
def deprecated_function() -> None:
    """Do nothing; this function is wrapped with ``@deprecated``."""
    return None
45
46
@experimental("experimental_function message")
def experimental_function() -> None:
    """Do nothing; this function is wrapped with ``@experimental``."""
    return None
50
51
52#
53# Methods
54#
55
56
class TestClass:
    """Holds one undecorated, one deprecated and one experimental method."""

    def undecorated_method(self) -> None:
        """Do nothing; this method carries no lifecycle decorator."""
        return None

    @deprecated("deprecated_method message")
    def deprecated_method(self) -> None:
        """Do nothing; this method is wrapped with ``@deprecated``."""
        return None

    @experimental("experimental_method message")
    def experimental_method(self) -> None:
        """Do nothing; this method is wrapped with ``@experimental``."""
        return None
68
69
70# NOTE: Variables and fields cannot be decorated.
71
72
class TestApiLifecycle(unittest.TestCase):
    """Checks which warning category each lifecycle decorator emits.

    ``is_deprecated`` and ``is_experimental`` invoke a zero-argument callable
    while recording warnings, then report whether a warning of the relevant
    category (optionally containing a given message) was emitted.
    """

    def _emits_warning(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
        matches: Callable[[warnings.WarningMessage], bool],
        message: Optional[str],
    ) -> bool:
        """Invoke ``callable`` and report whether a matching warning was emitted.

        Every recorded warning is inspected, not just the most recent one, so
        an unrelated warning emitted before or after the one of interest
        cannot mask it.

        Args:
            callable: Zero-argument callable to invoke; its result is ignored.
            matches: Predicate selecting the recorded warnings of interest.
            message: If truthy, a warning only counts when this text is
                contained in the warning's message.

        Returns:
            True if at least one recorded warning satisfies ``matches`` and
            the optional message check; False otherwise.
        """
        with warnings.catch_warnings(record=True) as recorded:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")

            # Try to trigger a warning.
            callable()

            if not recorded:
                # No warnings were triggered.
                return False
            return any(
                matches(entry) and (not message or message in str(entry.message))
                for entry in recorded
            )

    def is_deprecated(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
        message: Optional[str] = None,
    ) -> bool:
        """Return True if calling ``callable`` emits a deprecation warning.

        Args:
            callable: Zero-argument callable to invoke.
            message: Optional text the warning's message must contain.

        Returns:
            True if a ``DeprecationWarning`` that is not an
            ``ExperimentalWarning`` was emitted (and contains ``message`` when
            one is given); False otherwise.
        """

        def is_plain_deprecation(entry: warnings.WarningMessage) -> bool:
            # ExperimentalWarning is a subclass of DeprecationWarning, so it
            # has to be excluded explicitly.
            return issubclass(entry.category, DeprecationWarning) and not issubclass(
                entry.category, ExperimentalWarning
            )

        return self._emits_warning(callable, is_plain_deprecation, message)

    def is_experimental(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
        message: Optional[str] = None,
    ) -> bool:
        """Return True if calling ``callable`` emits an ``ExperimentalWarning``.

        Args:
            callable: Zero-argument callable to invoke.
            message: Optional text the warning's message must contain.

        Returns:
            True if an ``ExperimentalWarning`` was emitted (and contains
            ``message`` when one is given); False otherwise.
        """

        def is_experimental_warning(entry: warnings.WarningMessage) -> bool:
            return issubclass(entry.category, ExperimentalWarning)

        return self._emits_warning(callable, is_experimental_warning, message)

    # Classes: instantiating the class is what triggers the warning.

    def test_undecorated_class(self) -> None:
        self.assertFalse(self.is_deprecated(UndecoratedClass))
        self.assertFalse(self.is_experimental(UndecoratedClass))

    def test_deprecated_class(self) -> None:
        self.assertTrue(self.is_deprecated(DeprecatedClass, "DeprecatedClass message"))
        self.assertFalse(self.is_experimental(DeprecatedClass))

    def test_experimental_class(self) -> None:
        self.assertFalse(self.is_deprecated(ExperimentalClass))
        self.assertTrue(
            self.is_experimental(ExperimentalClass, "ExperimentalClass message")
        )

    # Functions: calling the function is what triggers the warning.

    def test_undecorated_function(self) -> None:
        self.assertFalse(self.is_deprecated(undecorated_function))
        self.assertFalse(self.is_experimental(undecorated_function))

    def test_deprecated_function(self) -> None:
        self.assertTrue(
            self.is_deprecated(deprecated_function, "deprecated_function message")
        )
        self.assertFalse(self.is_experimental(deprecated_function))

    def test_experimental_function(self) -> None:
        self.assertFalse(self.is_deprecated(experimental_function))
        self.assertTrue(
            self.is_experimental(experimental_function, "experimental_function message")
        )

    # Methods: calling the bound method is what triggers the warning.

    def test_undecorated_method(self) -> None:
        tc = TestClass()
        self.assertFalse(self.is_deprecated(tc.undecorated_method))
        self.assertFalse(self.is_experimental(tc.undecorated_method))

    def test_deprecated_method(self) -> None:
        tc = TestClass()
        self.assertTrue(
            self.is_deprecated(tc.deprecated_method, "deprecated_method message")
        )
        self.assertFalse(self.is_experimental(tc.deprecated_method))

    def test_experimental_method(self) -> None:
        tc = TestClass()
        self.assertFalse(self.is_deprecated(tc.experimental_method))
        self.assertTrue(
            self.is_experimental(tc.experimental_method, "experimental_method message")
        )
170