1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5import unittest 6import warnings 7from typing import Dict, List, Optional 8 9import torch 10 11 12# Make the helper files in test/ importable 13pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 14sys.path.append(pytorch_test_dir) 15from torch.testing._internal.jit_utils import JitTestCase 16 17 18if __name__ == "__main__": 19 raise RuntimeError( 20 "This test file is not meant to be run directly, use:\n\n" 21 "\tpython test/test_jit.py TESTNAME\n\n" 22 "instead." 23 ) 24 25 26class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): 27 # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact, 28 # reassigning a non-empty Tuple to an attribute previously typed 29 # as containing an empty Tuple SHOULD fail. See note in `_check.py` 30 31 def test_annotated_falsy_base_type(self): 32 class M(torch.nn.Module): 33 def __init__(self) -> None: 34 super().__init__() 35 self.x: int = 0 36 37 def forward(self, x: int): 38 self.x = x 39 return 1 40 41 with warnings.catch_warnings(record=True) as w: 42 self.checkModule(M(), (1,)) 43 assert len(w) == 0 44 45 def test_annotated_nonempty_container(self): 46 class M(torch.nn.Module): 47 def __init__(self) -> None: 48 super().__init__() 49 self.x: List[int] = [1, 2, 3] 50 51 def forward(self, x: List[int]): 52 self.x = x 53 return 1 54 55 with warnings.catch_warnings(record=True) as w: 56 self.checkModule(M(), ([1, 2, 3],)) 57 assert len(w) == 0 58 59 def test_annotated_empty_tensor(self): 60 class M(torch.nn.Module): 61 def __init__(self) -> None: 62 super().__init__() 63 self.x: torch.Tensor = torch.empty(0) 64 65 def forward(self, x: torch.Tensor): 66 self.x = x 67 return self.x 68 69 with warnings.catch_warnings(record=True) as w: 70 self.checkModule(M(), (torch.rand(2, 3),)) 71 assert len(w) == 0 72 73 def test_annotated_with_jit_attribute(self): 74 class M(torch.nn.Module): 75 def __init__(self) -> None: 76 super().__init__() 77 self.x = torch.jit.Attribute([], List[int]) 78 79 def forward(self, x: List[int]): 80 self.x = x 81 return self.x 82 83 with warnings.catch_warnings(record=True) as w: 84 self.checkModule(M(), ([1, 2, 3],)) 85 assert len(w) == 0 86 87 def test_annotated_class_level_annotation_only(self): 88 class M(torch.nn.Module): 89 x: List[int] 90 91 def __init__(self) -> None: 92 super().__init__() 93 self.x = [] 94 95 def forward(self, y: List[int]): 96 self.x = y 97 return self.x 98 99 with warnings.catch_warnings(record=True) as w: 100 self.checkModule(M(), ([1, 2, 3],)) 101 assert len(w) == 0 102 103 def test_annotated_class_level_annotation_and_init_annotation(self): 104 class M(torch.nn.Module): 105 x: List[int] 106 107 def __init__(self) -> None: 108 super().__init__() 109 self.x: List[int] = [] 110 111 def forward(self, y: List[int]): 112 self.x = y 113 return self.x 114 115 with warnings.catch_warnings(record=True) as w: 116 self.checkModule(M(), ([1, 2, 3],)) 117 assert len(w) == 0 118 119 def test_annotated_class_level_jit_annotation(self): 120 class M(torch.nn.Module): 121 x: List[int] 122 123 def __init__(self) -> None: 124 super().__init__() 125 self.x: List[int] = torch.jit.annotate(List[int], []) 126 127 def forward(self, y: List[int]): 128 self.x = y 129 return self.x 130 131 with warnings.catch_warnings(record=True) as w: 132 self.checkModule(M(), ([1, 2, 3],)) 133 assert len(w) == 0 134 135 def test_annotated_empty_list(self): 136 class M(torch.nn.Module): 137 def __init__(self) -> None: 138 super().__init__() 139 self.x: List[int] = [] 140 141 def forward(self, x: List[int]): 142 self.x = x 143 return 1 144 145 with self.assertRaisesRegexWithHighlight( 146 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 147 ): 148 with self.assertWarnsRegex( 149 UserWarning, 150 "doesn't support " 151 "instance-level annotations on " 152 "empty non-base types", 153 ): 154 torch.jit.script(M()) 155 156 @unittest.skipIf( 157 sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" 158 ) 159 def test_annotated_empty_list_lowercase(self): 160 class M(torch.nn.Module): 161 def __init__(self) -> None: 162 super().__init__() 163 self.x: list[int] = [] 164 165 def forward(self, x: list[int]): 166 self.x = x 167 return 1 168 169 with self.assertRaisesRegexWithHighlight( 170 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 171 ): 172 with self.assertWarnsRegex( 173 UserWarning, 174 "doesn't support " 175 "instance-level annotations on " 176 "empty non-base types", 177 ): 178 torch.jit.script(M()) 179 180 def test_annotated_empty_dict(self): 181 class M(torch.nn.Module): 182 def __init__(self) -> None: 183 super().__init__() 184 self.x: Dict[str, int] = {} 185 186 def forward(self, x: Dict[str, int]): 187 self.x = x 188 return 1 189 190 with self.assertRaisesRegexWithHighlight( 191 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 192 ): 193 with self.assertWarnsRegex( 194 UserWarning, 195 "doesn't support " 196 "instance-level annotations on " 197 "empty non-base types", 198 ): 199 torch.jit.script(M()) 200 201 @unittest.skipIf( 202 sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" 203 ) 204 def test_annotated_empty_dict_lowercase(self): 205 class M(torch.nn.Module): 206 def __init__(self) -> None: 207 super().__init__() 208 self.x: dict[str, int] = {} 209 210 def forward(self, x: dict[str, int]): 211 self.x = x 212 return 1 213 214 with self.assertRaisesRegexWithHighlight( 215 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 216 ): 217 with self.assertWarnsRegex( 218 UserWarning, 219 "doesn't support " 220 "instance-level annotations on " 221 "empty non-base types", 222 ): 223 torch.jit.script(M()) 224 225 def test_annotated_empty_optional(self): 226 class M(torch.nn.Module): 227 def __init__(self) -> None: 228 super().__init__() 229 self.x: Optional[str] = None 230 231 def forward(self, x: Optional[str]): 232 self.x = x 233 return 1 234 235 with self.assertRaisesRegexWithHighlight( 236 RuntimeError, "Wrong type for attribute assignment", "self.x = x" 237 ): 238 with self.assertWarnsRegex( 239 UserWarning, 240 "doesn't support " 241 "instance-level annotations on " 242 "empty non-base types", 243 ): 244 torch.jit.script(M()) 245 246 def test_annotated_with_jit_empty_list(self): 247 class M(torch.nn.Module): 248 def __init__(self) -> None: 249 super().__init__() 250 self.x = torch.jit.annotate(List[int], []) 251 252 def forward(self, x: List[int]): 253 self.x = x 254 return 1 255 256 with self.assertRaisesRegexWithHighlight( 257 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 258 ): 259 with self.assertWarnsRegex( 260 UserWarning, 261 "doesn't support " 262 "instance-level annotations on " 263 "empty non-base types", 264 ): 265 torch.jit.script(M()) 266 267 @unittest.skipIf( 268 sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" 269 ) 270 def test_annotated_with_jit_empty_list_lowercase(self): 271 class M(torch.nn.Module): 272 def __init__(self) -> None: 273 super().__init__() 274 self.x = torch.jit.annotate(list[int], []) 275 276 def forward(self, x: list[int]): 277 self.x = x 278 return 1 279 280 with self.assertRaisesRegexWithHighlight( 281 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 282 ): 283 with self.assertWarnsRegex( 284 UserWarning, 285 "doesn't support " 286 "instance-level annotations on " 287 "empty non-base types", 288 ): 289 torch.jit.script(M()) 290 291 def test_annotated_with_jit_empty_dict(self): 292 class M(torch.nn.Module): 293 def __init__(self) -> None: 294 super().__init__() 295 self.x = torch.jit.annotate(Dict[str, int], {}) 296 297 def forward(self, x: Dict[str, int]): 298 self.x = x 299 return 1 300 301 with self.assertRaisesRegexWithHighlight( 302 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 303 ): 304 with self.assertWarnsRegex( 305 UserWarning, 306 "doesn't support " 307 "instance-level annotations on " 308 "empty non-base types", 309 ): 310 torch.jit.script(M()) 311 312 @unittest.skipIf( 313 sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)" 314 ) 315 def test_annotated_with_jit_empty_dict_lowercase(self): 316 class M(torch.nn.Module): 317 def __init__(self) -> None: 318 super().__init__() 319 self.x = torch.jit.annotate(dict[str, int], {}) 320 321 def forward(self, x: dict[str, int]): 322 self.x = x 323 return 1 324 325 with self.assertRaisesRegexWithHighlight( 326 RuntimeError, "Tried to set nonexistent attribute", "self.x = x" 327 ): 328 with self.assertWarnsRegex( 329 UserWarning, 330 "doesn't support " 331 "instance-level annotations on " 332 "empty non-base types", 333 ): 334 torch.jit.script(M()) 335 336 def test_annotated_with_jit_empty_optional(self): 337 class M(torch.nn.Module): 338 def __init__(self) -> None: 339 super().__init__() 340 self.x = torch.jit.annotate(Optional[str], None) 341 342 def forward(self, x: Optional[str]): 343 self.x = x 344 return 1 345 346 with self.assertRaisesRegexWithHighlight( 347 RuntimeError, "Wrong type for attribute assignment", "self.x = x" 348 ): 349 with self.assertWarnsRegex( 350 UserWarning, 351 "doesn't support " 352 "instance-level annotations on " 353 "empty non-base types", 354 ): 355 torch.jit.script(M()) 356 357 def test_annotated_with_torch_jit_import(self): 358 from torch import jit 359 360 class M(torch.nn.Module): 361 def __init__(self) -> None: 362 super().__init__() 363 self.x = jit.annotate(Optional[str], None) 364 365 def forward(self, x: Optional[str]): 366 self.x = x 367 return 1 368 369 with self.assertRaisesRegexWithHighlight( 370 RuntimeError, "Wrong type for attribute assignment", "self.x = x" 371 ): 372 with self.assertWarnsRegex( 373 UserWarning, 374 "doesn't support " 375 "instance-level annotations on " 376 "empty non-base types", 377 ): 378 torch.jit.script(M()) 379