1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from collections import namedtuple 6from typing import Dict, List, NamedTuple, Tuple 7 8import torch 9from torch.testing._internal.common_utils import IS_WINDOWS 10from torch.testing._internal.jit_utils import JitTestCase, make_global 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16 17if __name__ == "__main__": 18 raise RuntimeError( 19 "This test file is not meant to be run directly, use:\n\n" 20 "\tpython test/test_jit.py TESTNAME\n\n" 21 "instead." 22 ) 23 24 25class TestTyping(JitTestCase): 26 def test_dict_in_not_in(self): 27 def test_in_dict(x): 28 # type: (Dict[str, int]) -> bool 29 return "hi" in x 30 31 self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},)) 32 self.checkScript(test_in_dict, ({"bye": 3},)) 33 34 # Check evaluation order 35 @torch.jit.script 36 def a(): 37 print("a") 38 return 3 39 40 @torch.jit.script 41 def b(): 42 print("b") 43 return {3: 2, 4: 1} 44 45 @torch.jit.script 46 def fn(): 47 return a() in b() 48 49 with self.capture_stdout() as captured: 50 self.assertTrue(fn()) 51 if not IS_WINDOWS: 52 # no stdout capturing on windows 53 self.assertEqual(captured[0], "a\nb\n") 54 55 def test_not_in_dict(a): 56 # type: (Dict[str, int]) -> bool 57 if "hello" not in a: 58 return False 59 else: 60 return True 61 62 self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},)) 63 self.checkScript(test_not_in_dict, ({"world": 2},)) 64 65 def test_dict_tensor_key(a, t): 66 # type: (Dict[Tensor, int], Tensor) -> bool 67 if t in a: 68 return True 69 else: 70 return False 71 72 inp1 = torch.tensor(3) 73 inp2 = torch.tensor(5) 74 dict_a = {inp1: 1, inp2: 3} 75 self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4))) 76 self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3))) 77 self.checkScript(test_dict_tensor_key, (dict_a, inp1)) 78 self.checkScript(test_dict_tensor_key, (dict_a, inp2)) 79 80 def test_list_type_refinement_annotation_element_mismatch(self): 81 def fn(): 82 l: List[int] = [1, 2, "foo", 3] 83 return l 84 85 with self.assertRaisesRegex( 86 RuntimeError, 87 "List type annotation" 88 r" `List\[int\]` did not match the " 89 "types of the given list elements", 90 ): 91 torch.jit.script(fn) 92 93 def test_dict_type_refinement_annotation_key_mismatch(self): 94 def fn(): 95 l1 = [1, 2, "foo", 3] 96 l2 = ["foo", "bar", "baz", "qux"] 97 d: Dict[int, str] = dict(zip(l1, l2)) 98 return d 99 100 with self.assertRaisesRegex( 101 RuntimeError, 102 "Dicts may only " 103 "contain homogeneous keys, but the " 104 "type of the first generated key " 105 r"was Union\[int, str\]", 106 ): 107 torch.jit.script(fn) 108 109 def test_dict_type_refinement_annotation_value_mismatch(self): 110 def fn(): 111 l1 = ["foo", "bar", "baz", "qux"] 112 l2 = [1, 2, "foo", 3] 113 d: Dict[str, int] = dict(zip(l1, l2)) 114 return d 115 116 with self.assertRaisesRegex( 117 RuntimeError, 118 "Dict type annotation" 119 r" `Dict\[str, int\]` did not match" 120 " the type of an actual value type" 121 r" `Union\[int, str\]`", 122 ): 123 torch.jit.script(fn) 124 125 def test_dict_invalid_annotations(self): 126 # Check for invalid value type annotation 127 def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]): 128 return 129 130 with self.assertRaisesRegex(ValueError, "Unknown type annotation"): 131 torch.jit.script(wrong_value_type) 132 133 # Check for invalid key type annotation 134 def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): 135 return 136 137 with self.assertRaisesRegex(ValueError, "Unknown type annotation"): 138 torch.jit.script(wrong_key_type) 139 140 # Check for invalid key and value type annotation 141 def wrong_key_value_type( 142 dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule] 143 ): 144 return 145 146 with self.assertRaisesRegex(ValueError, "Unknown type annotation"): 147 torch.jit.script(wrong_key_value_type) 148 149 def test_tuple_specialization(self): 150 @torch.jit.script 151 def f(t, s): 152 # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor 153 x, t2 = t 154 _, y = t2 155 return x + y 156 157 t = ( 158 torch.randn(2, 2), 159 (1, torch.randn(2, 2)), 160 ) 161 f(t, "hi") 162 graph = f.graph_for(t, "hi") 163 input_types = list(next(graph.inputs()).type().elements()) 164 w = input_types[0] 165 self.assertEqual(input_types[0].kind(), "TensorType") 166 self.assertEqual(input_types[1].elements()[1].kind(), "TensorType") 167 168 def test_tuple_io(self): 169 def stuff(x): 170 # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor] 171 a, b = x 172 return b, a 173 174 a = (torch.rand(3), torch.rand(3)) 175 self.checkScript(stuff, (a,)) 176 177 def test_tuple_keyword(self): 178 def bar(): 179 f = tuple((1, 2)) # noqa: C409 180 return f 181 182 self.checkScript(bar, ()) 183 184 def foo(): 185 return tuple(1, 2) 186 187 self.checkScriptRaisesRegex(foo, (), Exception, "1 argument") 188 189 def cant_infer_size(): 190 return tuple([1, 2, 3]) # noqa: C409 191 192 with self.assertRaisesRegex(Exception, "cannot statically infer the expected"): 193 torch.jit.script(cant_infer_size) 194 195 def test_tuple_create_return(self): 196 def stuff2(x): 197 # type: (int) -> Tuple[Tensor, Tensor] 198 a = (torch.ones(x), torch.zeros(x)) 199 return a 200 201 self.checkScript(stuff2, (3,)) 202 203 def test_list_io(self): 204 def stuff3(x): 205 # type: (List[int]) -> Tuple[Tensor, List[int]] 206 return torch.ones(x), x 207 208 self.checkScript(stuff3, ([3, 2],)) 209 210 def test_bool_list_io(self): 211 @torch.jit.script 212 def stuff4(x): 213 # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]] 214 return x, [True, False], [[True]] 215 216 li_1, li_2, li_3 = stuff4([True]) 217 li_3 = li_3[0] 218 for li in [li_1, li_2, li_3]: 219 self.assertTrue(type(li[0]) == bool) 220 221 def test_nested_list(self): 222 def foo(z): 223 # type: (Tuple[int, List[List[int]]]) -> int 224 x, y = z 225 return y[0][1] 226 227 self.checkScript(foo, ((1, [[1, 2], [3, 4]]),)) 228 229 def test_list_sum(self): 230 def fn(x: List[int]) -> int: 231 return sum(x) 232 233 def fn1(x: List[float]): 234 return sum(x) 235 236 def fn2(x: List[bool]): 237 return sum(x) 238 239 self.checkScript(fn, ([1, 2, 3],)) 240 self.checkScript(fn1, ([1.0, 2.0, 3.0],)) 241 self.checkScript(fn1, ([1, 2.8, 3],)) 242 self.checkScript(fn2, ([True, False, False],)) 243 self.checkScript(fn2, ([False, False, False],)) 244 self.checkScript(fn2, ([0, 1, 1, 0],)) 245 246 def test_list_unification(self): 247 def fn(): 248 return [1, None, 2] 249 250 def fn2(x): 251 return [torch.ones(2, 2), None, x] 252 253 self.checkScript(fn, []) 254 self.checkScript(fn2, (torch.ones(2, 2),)) 255 256 # to avoid defining sum_list in multiple tests 257 def get_sum_list_fn(self): 258 def sum_list(a): 259 # type: (List[int]) -> int 260 sum = 0 261 for i in a: 262 sum += i 263 264 return sum 265 266 return sum_list 267 268 def test_sum_list_diff_elms(self): 269 self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) 270 271 def test_sum_list_empty(self): 272 self.checkScript(self.get_sum_list_fn(), ([],)) 273 274 def test_sum_list_one(self): 275 self.checkScript(self.get_sum_list_fn(), ([1],)) 276 277 def test_sum_list_literal(self): 278 def sum_list(): 279 # type: () -> int 280 sum = 0 281 for i in [1, 2, 3, 4, 5]: 282 sum += i 283 284 return sum 285 286 self.checkScript(sum_list, ()) 287 288 def test_sum_list_wrong_type(self): 289 with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): 290 291 @torch.jit.script 292 def sum_list(a): 293 # type: (int) -> int 294 sum = 0 295 for i in a: # noqa: T484 296 sum += i 297 298 return sum 299 300 sum_list(1) 301 302 def test_list_iterables(self): 303 with self.assertRaisesRegex( 304 RuntimeError, "List of iterables is not supported currently" 305 ): 306 cu = torch.jit.CompilationUnit( 307 """ 308 def list_iterables(x): 309 for i, j in [2, 3, 4], [5, 6, 7]: 310 x += i 311 x += j 312 return x 313 """ 314 ) 315 316 def test_for_in_string(self): 317 def test_strings(x): 318 # type: (str) -> str 319 reverse = "" 320 for c in x: 321 reverse = c + reverse 322 return reverse 323 324 self.checkScript(test_strings, ("hello",)) 325 self.checkScript(test_strings, ("",)) 326 327 def test_list_strings(x): 328 # type: (List[str]) -> str 329 result = "" 330 for sub_str in x: 331 result += sub_str 332 return result 333 334 self.checkScript(test_list_strings, (["hello", "world"],)) 335 self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) 336 337 def test_for_in_dict(self): 338 def test_dicts(x): 339 # type: (Dict[str, int]) -> int 340 sum = 0 341 for key in x: 342 sum += x[key] 343 return sum 344 345 self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 346 347 def test_dict_keys_values(x): 348 # type: (Dict[str, int]) -> Tuple[str, int] 349 key_str = "" 350 sum = 0 351 for key in x.keys(): 352 key_str += key 353 for val in x.values(): 354 sum += val 355 return key_str, sum 356 357 self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 358 359 def test_for_tuple_unpack(self): 360 def for_tuple_unpack(x, y): 361 for i, j in [[3, 4], [5, 6], [7, 8]]: 362 x += i 363 y += j 364 return x, y 365 366 self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) 367 368 def nested_tuple_unpack(x, y): 369 # type: (List[int], List[int]) -> int 370 sum = 0 371 for i, (j, k), v in zip(x, enumerate(x), y): 372 sum += i + j + k + v 373 return sum 374 375 self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) 376 377 def test_dict_comprehension(self): 378 def fn(): 379 return {i: chr(i + 65) for i in range(4)} 380 381 self.checkScript(fn, ()) 382 383 def test_dict_comprehension_with_type_annotation(self): 384 def fn(): 385 d: Dict[int, str] = {i: chr(i + 65) for i in range(4)} 386 return d 387 388 self.checkScript(fn, ()) 389 390 with self.assertRaisesRegex(RuntimeError, ""): 391 with self.assertRaisesRegex( 392 AssertionError, 393 "Expected Dict " 394 "type annotation for dict " 395 "comprehension, found " 396 "Tuple[int, str]", 397 ): 398 399 @torch.jit.script 400 def fn(): 401 d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)} 402 return d 403 404 def test_dict_comprehension_scope(self): 405 def comprehension_can_access_outer_scope_variables(): 406 lst = ["foo", "bar", "baz"] 407 return {l: len(l) for l in lst} 408 409 self.checkScript(comprehension_can_access_outer_scope_variables, ()) 410 411 with self.assertRaisesRegex(RuntimeError, "undefined value i"): 412 413 @torch.jit.script 414 def outer_scope_cannot_access_comprehension_variables(): 415 d = {i: chr(i + 65) for i in range(4)} 416 i = i + 1 # noqa: F821 417 418 def test_for_tuple_assign(self): 419 def test_simple_assign(x): 420 # type: (Tuple[int, float]) -> float 421 sum = 0.0 422 for a in x: 423 sum += float(a) 424 return sum 425 426 self.checkScript(test_simple_assign, ((1, 2.5),)) 427 428 def test_tuple_assign(x): 429 # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int 430 sum = 0 431 for a in x: 432 sum += a[0] 433 sum += a[1] 434 return sum 435 436 self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),)) 437 438 def test_single_starred_lhs(self): 439 with self.assertRaisesRegex( 440 RuntimeError, 441 "A Starred expression may only appear on the lhs within the presence" 442 " of another non-starred expression", 443 ): 444 cu = torch.jit.CompilationUnit( 445 """ 446 def single_starred_lhs(x): 447 a = (x, x, x) 448 *b, = a 449 return b 450 """ 451 ) 452 453 def test_singleton_tuple_unpack(self): 454 def foo(a): 455 (b,) = (a,) 456 return b + 1 457 458 self.checkScript(foo, (torch.rand(3),)) 459 460 def test_tuple_assignments(self): 461 def var_tuple_assign(x, y): 462 # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor 463 (a, b), c = x, y 464 return a + b + c 465 466 tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) 467 self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) 468 469 def nested_tuple_assign(x, y, z): 470 # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int 471 a, (b, (c, d)), (e, f) = x, y, z 472 return a + b + c + d + e + f 473 474 self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) 475 476 def subscript_tuple_assign(a, x, i): 477 # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] 478 a[i], (x[i], b) = 1, (2, 3) 479 return a[i] + 1, x + 5, b 480 481 self.checkScript( 482 subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0) 483 ) 484 485 def star_tuple_assign(): 486 # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] 487 a, (b, *c), *d = 1, (2, 3, 4), 5, 6 488 return a, b, c, d 489 490 self.checkScript(star_tuple_assign, ()) 491 492 def subscript_tuple_augmented_assign(a): 493 # type: (Tuple[int, int]) -> Tuple[int, int] 494 a[0] += 1 495 return a 496 497 with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"): 498 scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) 499 500 def test_multiple_assign(self): 501 def test(): 502 a = b, c = d, f = (1, 1) 503 504 # side effect 505 ten = torch.tensor(1) 506 ten1 = ten2 = ten.add_(1) 507 508 # ordering 509 x = 1 510 y = 3 511 x, y = y, x + y 512 513 return a, b, c, d, f, ten, ten1, ten2, x, y 514 515 self.checkScript(test, ()) 516 517 def test_opt_opt_refinement(self): 518 @torch.jit.script 519 def test_unify(weight, bias): 520 # type: (Optional[int], Optional[int]) -> Optional[int] 521 if weight is not None: 522 opt = None 523 else: 524 if bias is not None: 525 opt = 1 526 else: 527 opt = None 528 529 return opt 530 531 def test_optional_refinement(self): 532 @torch.jit.script 533 def test_if_none_assignment(x): 534 # type: (Optional[int]) -> int 535 if x is None: 536 x = 1 537 return x + 1 538 539 self.assertEqual(test_if_none_assignment(1), 2) 540 541 def test_optional_conversion(self): 542 @torch.jit.script 543 def other_fn(x=None): 544 # type: (Optional[int]) -> int 545 return torch.jit._unwrap_optional(x) 546 547 @torch.jit.script 548 def fn(x): 549 # type: (int) -> int 550 return other_fn(x) 551 552 self.assertEqual(fn(2), 2) 553 554 @torch.jit.script 555 def unify_to_optional(x): 556 # type: (bool) -> Optional[int] 557 if x: 558 a = None 559 else: 560 a = 2 561 return a 562 563 self.assertEqual(unify_to_optional(True), None) 564 self.assertEqual(unify_to_optional(False), 2) 565 566 @torch.jit.script 567 def opt_list(x): 568 # type: (Optional[List[float]]) -> int 569 return 2 570 571 @torch.jit.script 572 def broadcast_opt_list(x): 573 # type: (Optional[BroadcastingList2[float]]) -> int 574 return 2 575 576 @torch.jit.script 577 def opt_list_tuple_caller(x): 578 # type: (Tuple[float, float]) -> int 579 return opt_list(x) + broadcast_opt_list(x) 580 581 self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4) 582 583 def test_optional_tuple(self): 584 def fn(x=None): 585 # type: (Optional[Tuple[int, int]]) -> Tuple[int, int] 586 if x is None: 587 new_x = (1, 2) 588 else: 589 new_x = x 590 return new_x 591 592 self.checkScript(fn, ((3, 4),)) 593 self.checkScript(fn, ()) 594 595 def test_namedtuple_redefine(self): 596 global _1, _2 597 _1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) 598 _2 = namedtuple("GoogLeNetOutputs", ["different"]) 599 600 with self.assertRaisesRegex(RuntimeError, r"redefine"): 601 602 @torch.jit.script 603 def foo(x, y): 604 # type: (_1, _2) -> _1 605 return x 606 607 def test_namedtuple_py2(self): 608 global _GoogLeNetOutputs # see [local resolution in python] 609 _GoogLeNetOutputs = namedtuple( 610 "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"] 611 ) 612 613 @torch.jit.script 614 def foo(x): 615 # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs 616 return x 617 618 vals = torch.rand(3), torch.rand(4), torch.rand(5) 619 out = foo( 620 _GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]) 621 ) 622 self.assertEqual(out.logits, vals[0]) 623 self.assertEqual(out.aux_logits2, vals[1]) 624 self.assertEqual(out.aux_logits1, vals[2]) 625 626 def test_namedtuple_good_error(self): 627 global _GoogLeNetOutputs # see [local resolution in python] 628 _GoogLeNetOutputs = namedtuple( 629 "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"] 630 ) 631 632 @torch.jit.script 633 def foo(x): 634 # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs 635 return x 636 637 with self.assertRaisesRegex( 638 RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)" 639 ): 640 out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5")) 641 642 def test_namedtuple_error_source_attribution(self): 643 class _NamedTupleBadMemberType(NamedTuple): 644 f1: torch.Tensor 645 f2: "ABadForwardRefType" # noqa: F821 646 647 make_global(_NamedTupleBadMemberType) # see [local resolution in python] 648 649 def fn(x: _NamedTupleBadMemberType) -> torch.Tensor: 650 return x.f1.relu() 651 652 # assert that this has a location associated with the error. 653 # note the " +" is regex (i.e. "at least one space") 654 with self.assertRaisesRegex(ValueError, "at +File"): 655 torch.jit.script(fn) 656 657 def test_inherited_annotations_python_310(self): 658 # See #104484 659 # In python >=3.10, inspect.get_annotations doesn't always return the same values. 660 # Sometimes it will show all annotations; other times it will show only annotations 661 # that show in that class, not classes it inherits fro. 662 class BaseModule(torch.nn.Module): 663 state: List[int] 664 665 def forward(self, x): 666 pass 667 668 def do_something_with_list(x: List[int]): 669 if x: 670 return x[-1] 671 return 5 672 673 class Submodule(BaseModule): 674 def __init__(self, self_x_value): 675 super().__init__() 676 self.x = self_x_value 677 self.state = [] 678 679 def forward(self, x): 680 return self.x + x + do_something_with_list(self.state) 681 682 class LowestModule(Submodule): 683 def __init__(self) -> None: 684 super().__init__(123) 685 686 mod = LowestModule() 687 mod2 = LowestModule() 688 mod_s = torch.jit.script(mod) 689 mod2_s = torch.jit.script(mod2) 690