# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import unittest
import warnings
from typing import Any, Callable, List, Optional

from executorch.exir._warnings import deprecated, experimental, ExperimentalWarning

#
# Classes
#


class UndecoratedClass:
    pass


@deprecated("DeprecatedClass message")
class DeprecatedClass:
    pass


@experimental("ExperimentalClass message")
class ExperimentalClass:
    pass


#
# Functions
#


def undecorated_function() -> None:
    pass


@deprecated("deprecated_function message")
def deprecated_function() -> None:
    pass


@experimental("experimental_function message")
def experimental_function() -> None:
    pass


#
# Methods
#


class TestClass:
    def undecorated_method(self) -> None:
        pass

    @deprecated("deprecated_method message")
    def deprecated_method(self) -> None:
        pass

    @experimental("experimental_method message")
    def experimental_method(self) -> None:
        pass


# NOTE: Variables and fields cannot be decorated.


class TestApiLifecycle(unittest.TestCase):
    """Checks which lifecycle warning, if any, each decorated target emits."""

    def _record_warnings(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
    ) -> List[warnings.WarningMessage]:
        """Call ``callable`` once and return every warning it emitted, in order.

        Exceptions raised by ``callable`` propagate to the caller.
        """
        with warnings.catch_warnings(record=True) as recorded:
            # Cause all warnings to always be triggered, even if an identical
            # warning was already emitted from the same location earlier.
            warnings.simplefilter("always")

            # Try to trigger a warning.
            callable()
        return list(recorded)

    def is_deprecated(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
        message: Optional[str] = None,
    ) -> bool:
        """Return True if calling ``callable`` emits a plain DeprecationWarning.

        Every recorded warning is inspected, not only the last one, so an
        unrelated warning emitted afterwards cannot hide the deprecation.

        Args:
            callable: Zero-argument callable to invoke (e.g. a class or a bound
                method).
            message: If non-empty, the matching warning's text must contain it.

        Returns:
            True if some recorded warning is a DeprecationWarning that is not an
            ExperimentalWarning (and matches ``message`` when given).
        """
        for record in self._record_warnings(callable):
            category = record.category
            if not issubclass(category, DeprecationWarning):
                # There was a warning, but it wasn't a DeprecationWarning.
                continue
            if issubclass(category, ExperimentalWarning):
                # ExperimentalWarning is a subclass of DeprecationWarning, so it
                # must be excluded explicitly.
                continue
            if not message or message in str(record.message):
                return True
        return False

    def is_experimental(
        self,
        callable: Callable[[], Any],  # pyre-ignore[2]: Any type
        message: Optional[str] = None,
    ) -> bool:
        """Return True if calling ``callable`` emits an ExperimentalWarning.

        Every recorded warning is inspected, not only the last one, so an
        unrelated warning emitted afterwards cannot hide the experimental one.

        Args:
            callable: Zero-argument callable to invoke (e.g. a class or a bound
                method).
            message: If non-empty, the matching warning's text must contain it.

        Returns:
            True if some recorded warning is an ExperimentalWarning (and matches
            ``message`` when given).
        """
        for record in self._record_warnings(callable):
            if not issubclass(record.category, ExperimentalWarning):
                # There was a warning, but it wasn't an ExperimentalWarning.
                continue
            if not message or message in str(record.message):
                return True
        return False

    def test_undecorated_class(self) -> None:
        self.assertFalse(self.is_deprecated(UndecoratedClass))
        self.assertFalse(self.is_experimental(UndecoratedClass))

    def test_deprecated_class(self) -> None:
        self.assertTrue(self.is_deprecated(DeprecatedClass, "DeprecatedClass message"))
        self.assertFalse(self.is_experimental(DeprecatedClass))

    def test_experimental_class(self) -> None:
        self.assertFalse(self.is_deprecated(ExperimentalClass))
        self.assertTrue(
            self.is_experimental(ExperimentalClass, "ExperimentalClass message")
        )

    def test_undecorated_function(self) -> None:
        self.assertFalse(self.is_deprecated(undecorated_function))
        self.assertFalse(self.is_experimental(undecorated_function))

    def test_deprecated_function(self) -> None:
        self.assertTrue(
            self.is_deprecated(deprecated_function, "deprecated_function message")
        )
        self.assertFalse(self.is_experimental(deprecated_function))

    def test_experimental_function(self) -> None:
        self.assertFalse(self.is_deprecated(experimental_function))
        self.assertTrue(
            self.is_experimental(experimental_function, "experimental_function message")
        )

    def test_undecorated_method(self) -> None:
        tc = TestClass()
        self.assertFalse(self.is_deprecated(tc.undecorated_method))
        self.assertFalse(self.is_experimental(tc.undecorated_method))

    def test_deprecated_method(self) -> None:
        tc = TestClass()
        self.assertTrue(
            self.is_deprecated(tc.deprecated_method, "deprecated_method message")
        )
        self.assertFalse(self.is_experimental(tc.deprecated_method))

    def test_experimental_method(self) -> None:
        tc = TestClass()
        self.assertFalse(self.is_deprecated(tc.experimental_method))
        self.assertTrue(
            self.is_experimental(tc.experimental_method, "experimental_method message")
        )