1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6from enum import Enum 7from textwrap import dedent 8from typing import Dict, List, Optional, Tuple, Union 9 10import torch 11from torch.testing import FileCheck 12 13 14# Make the helper files in test/ importable 15pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 16sys.path.append(pytorch_test_dir) 17from torch.testing._internal.jit_utils import JitTestCase, make_global 18 19 20if __name__ == "__main__": 21 raise RuntimeError( 22 "This test file is not meant to be run directly, use:\n\n" 23 "\tpython test/test_jit.py TESTNAME\n\n" 24 "instead." 25 ) 26 27 28class TestUnion(JitTestCase): 29 """ 30 This class tests the functionality of `Union`. 31 32 Note: It's important to be able to refine the type of a `Union` to 33 one of its internal types. Currently, there are differences in the 34 way Python expects `isinstance` checks and the way TorchScript 35 expects `isinstance` checks. This means that we can't use 36 `checkScript` in our test cases because either the eager mode or the 37 script mode wouldn't run! So, some test cases have separate but 38 equivalent functions to emulate `checkScript`. 39 """ 40 41 def test_check_union_annotation(self): 42 def test_func(a: Union[int, float], b: Optional[int]): 43 return 0 44 45 scripted_func = torch.jit.script(test_func) 46 graph_rep = str(scripted_func.graph) 47 code_rep = str(scripted_func.code) 48 # TS graph IR for Union should be annotated as Union() 49 FileCheck().check("Union(").check("int?").run(graph_rep) 50 # Serialized code for Union should be annotated as Union[] 51 FileCheck().check("Union[").check("Optional[int]").run(code_rep) 52 self.checkScript(test_func, (5, 6)) 53 # this shouldn't error out 54 torch._C.parse_ir(str(scripted_func.graph)) 55 56 def test_union_with_scalar_values(self): 57 def fn(x: Union[int, float]) -> str: 58 return "foo" 59 60 self.checkScript(fn, (1,)) 61 self.checkScript(fn, (1.0,)) 62 63 scripted = torch.jit.script(fn) 64 65 with self.assertRaisesRegex( 66 RuntimeError, 67 "Expected a member of" 68 r" Union\[float, int\] but " 69 "instead found type str", 70 ): 71 scripted("1") 72 73 def test_union_with_collections(self): 74 def fn(x: Union[Dict[str, int], List[int]]) -> str: 75 return "foo" 76 77 self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) 78 self.checkScript(fn, ([1, 2, 3],)) 79 80 scripted = torch.jit.script(fn) 81 82 with self.assertRaisesRegex( 83 RuntimeError, 84 "Expected a member of" 85 r" Union\[List\[int\], Dict\[str, " 86 r"int\]\] but instead found type " 87 r"Dict\[str, str\]", 88 ): 89 scripted({"foo": "bar", "baz": "qux"}) 90 91 with self.assertRaisesRegex( 92 RuntimeError, 93 "Expected a member of" 94 r" Union\[List\[int\], Dict\[str, " 95 r"int\]\] but instead found type " 96 r"List\[str\]", 97 ): 98 scripted(["foo", "bar", "baz"]) 99 100 with self.assertRaisesRegex( 101 RuntimeError, 102 "Expected a member of" 103 r" Union\[List\[int\], Dict\[str, " 104 r"int\]\] but instead found type " 105 "str", 106 ): 107 scripted("1") 108 109 def test_union_with_enum(self): 110 class Color(Enum): 111 RED = 1 112 GREEN = 2 113 114 make_global(Color) 115 116 def fn(x: Union[str, Color]) -> str: 117 return "foo" 118 119 self.checkScript(fn, (Color.RED,)) 120 self.checkScript(fn, ("red",)) 121 122 scripted = torch.jit.script(fn) 123 124 with self.assertRaisesRegex( 125 RuntimeError, 126 "Expected a member of" 127 r" Union\[__torch__.jit.test_union." 128 r"Color, str\] but instead found " 129 "type int", 130 ): 131 scripted(1) 132 133 def test_union_in_class_constructor(self): 134 @torch.jit.script # noqa: B903 135 class A: # noqa: B903 136 def __init__(self, x: Union[int, str]) -> None: 137 self.x = x 138 139 def fn(x: Union[str, int]) -> A: 140 return A(x) 141 142 self.assertEqual(fn("foo").x, "foo") 143 self.assertEqual(fn(1).x, 1) 144 145 scripted = torch.jit.script(fn) 146 147 with self.assertRaisesRegex( 148 RuntimeError, 149 "Expected a member of" 150 r" Union\[int, str\] but instead " 151 r"found type List\[str\]", 152 ): 153 scripted(["foo", "bar", "baz"]) 154 155 def test_union_return_type(self): 156 def fn(x: int) -> Union[int, str]: 157 return "foo" 158 159 self.checkScript(fn, (1,)) 160 161 def test_union_as_annotation(self): 162 def fn() -> Union[int, str]: 163 x: Union[int, str] = "foo" 164 return x 165 166 self.checkScript(fn, ()) 167 168 def test_union_as_annotation_in_typed_container(self): 169 def fn() -> None: 170 l: List[Union[int, str]] = [] 171 u1: Union[int, str] = "foo" 172 u2: Union[int, str] = 1 173 l.append(u1) 174 l.append(u2) 175 176 self.checkScript(fn, ()) 177 178 def test_union_as_annotation_py2(self): 179 def fn(): 180 # type: () -> Union[int, str] 181 x: Union[int, str] = "foo" 182 return x 183 184 self.checkScript(fn, ()) 185 186 def test_union_as_internal_tuple_type(self): 187 def fn(): 188 t: Tuple[Union[int, str], Union[int, str]] = (1, "foo") 189 return t 190 191 self.checkScript(fn, ()) 192 193 def test_union_variable_can_be_reassigned(self): 194 @torch.jit.script 195 def aux1(i: int): 196 return int(i**2) 197 198 @torch.jit.script 199 def aux2(s: str): 200 return s + s 201 202 def fn() -> Union[int, str]: 203 x: Union[int, str] = "foo" 204 i: int = 1 205 x = i 206 y: int = aux1(x) 207 z: str = aux2(str(y)) 208 x = z 209 return x 210 211 self.checkScript(fn, ()) 212 213 def test_union_does_not_replace_existing_annotated_type(self): 214 def fn(): 215 x: List[int] = [1, 2, 3] 216 x.append("foo") 217 return x 218 219 with self.assertRaisesRegex(RuntimeError, "Could not match type str"): 220 scripted = torch.jit.script(fn) 221 scripted() 222 223 def test_union_does_not_replace_existing_annotated_type_union(self): 224 def fn(): 225 x: List[Union[int, str]] = [1, "foo", 3] 226 x.append(2.0) 227 return x 228 229 with self.assertRaisesRegex(RuntimeError, "Could not match type float"): 230 scripted = torch.jit.script(fn) 231 scripted() 232 233 def test_union_does_not_replace_existing_annotated_type_empty_container(self): 234 def fn(): 235 x: List[int] = [] 236 x.append("foo") 237 return x 238 239 with self.assertRaisesRegex(RuntimeError, "Could not match type str"): 240 scripted = torch.jit.script(fn) 241 scripted() 242 243 def test_unions_of_unions_are_flattened(self): 244 @torch.jit.script 245 def fn(x: Union[Union[int, str], float]) -> str: 246 return "foo" 247 248 s = fn.graph 249 250 FileCheck().check("x : Union(float, int, str)").run(s) 251 252 def test_unions_of_a_single_argument_vanish(self): 253 @torch.jit.script 254 def fn(x: Union[int]) -> str: 255 return "foo" 256 257 s = fn.graph 258 259 FileCheck().check("x : int").run(s) 260 261 def test_union_redundant_arguments_are_skipped(self): 262 @torch.jit.script 263 def fn(x: Union[int, str, int]) -> str: 264 return "foo" 265 266 s = fn.graph 267 268 FileCheck().check("x : Union(int, str)").run(s) 269 270 def test_union_redundant_arguments_are_skipped_optional(self): 271 @torch.jit.script 272 def fn(x: Union[int, Optional[float], Optional[int]]) -> str: 273 return "foo" 274 275 s = fn.graph 276 277 FileCheck().check("x : Union(float, int, NoneType)").run(s) 278 279 def test_union_redundant_arguments_are_skipped_subtyping(self): 280 @torch.jit.script 281 def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str: 282 return "foo" 283 284 s = fn.graph 285 286 FileCheck().check("x : Union((int?, int), str)").run(s) 287 288 def test_union_redundant_arguments_are_skipped_container(self): 289 @torch.jit.script 290 def fn(x: Union[List[str], List[float], List[str]]) -> str: 291 return "foo" 292 293 s = fn.graph 294 295 FileCheck().check("x : Union(float[], str[])").run(s) 296 297 def test_union_argument_order_is_ignored(self): 298 @torch.jit.script 299 def fn1(x: Union[int, str]) -> str: 300 return "foo" 301 302 @torch.jit.script 303 def fn2(x: Union[str, int]) -> str: 304 return "foo" 305 306 for s in (fn1.graph, fn2.graph): 307 FileCheck().check("x : Union(int, str)").run(s) 308 309 def test_union_argument_order_is_ignored_container(self): 310 @torch.jit.script 311 def fn1(x: Union[List[str], List[int]]) -> str: 312 return "foo" 313 314 @torch.jit.script 315 def fn2(x: Union[List[int], List[str]]) -> str: 316 return "foo" 317 318 for s in (fn1.graph, fn2.graph): 319 FileCheck().check("x : Union(int[], str[])").run(s) 320 321 def test_union_T_None_is_equivalent_to_optional_T(self): 322 @torch.jit.script 323 def inner(x: Union[int, None]) -> int: 324 if x is not None: 325 return x 326 else: 327 return 5 328 329 @torch.jit.script 330 def fn1() -> int: 331 a: Optional[int] = 5 332 b: Optional[int] = None 333 a_ = inner(a) 334 b_ = inner(b) 335 return a_ + b_ 336 337 self.assertEqual(fn1(), 10) 338 339 @torch.jit.script 340 def inner2(x: Optional[int]) -> int: 341 if x is not None: 342 return x 343 else: 344 return 5 345 346 @torch.jit.script 347 def fn2() -> int: 348 a: Union[int, None] = 5 349 b: Union[int, None] = None 350 a_ = inner(a) 351 b_ = inner(b) 352 return a_ + b_ 353 354 self.assertEqual(fn2(), 10) 355 356 def test_union_optional_of_union_is_flattened(self): 357 @torch.jit.script 358 def fn(flag: int) -> Union[str, int, None]: 359 y: Union[int, str, None] = "foo" 360 if flag == 0: 361 x: Optional[Union[int, str]] = y 362 elif flag == 1: 363 x: Optional[Union[int, str]] = 1 364 else: 365 x: Optional[Union[int, str]] = None 366 return x 367 368 # Can't use `checkScript` because it will flag the fact that 369 # the original code has `Optional[Union[int, str]]` but the 370 # saved/loaded code has `Union[int, NoneType, str]` (even 371 # though this is exactly what we want) 372 self.assertEqual(fn(0), "foo") 373 self.assertEqual(fn(1), 1) 374 self.assertEqual(fn(2), None) 375 376 buffer = io.BytesIO() 377 torch.jit.save(fn, buffer) 378 buffer = io.BytesIO(buffer.getvalue()) 379 l = torch.jit.load(buffer) 380 381 s = l.code 382 383 FileCheck().check("Union[int, NoneType, str]").check( 384 "Union[int, NoneType, str]" 385 ).run(s) 386 387 def test_union_subclasses_larger_union(self): 388 def fn() -> Union[int, str, torch.Tensor]: 389 x: Union[int, str] = "foo" 390 return x 391 392 self.checkScript(fn, ()) 393 394 # TODO: We would like to eventually support this. The issue is being 395 # tracked at https://github.com/pytorch/pytorch/issues/58167 396 def test_union_as_dict_key(self): 397 def fn(): 398 x: Dict[Union[int, str], str] = {} 399 x["foo"] = "bar" 400 x[1] = 2 401 return x[1] 402 403 with self.assertRaisesRegex( 404 RuntimeError, 405 "only int, float, " 406 "complex, Tensor, device and string keys " 407 "are supported", 408 ): 409 torch.jit.script(fn) 410 411 def test_union_as_dict_value(self): 412 def fn(): 413 x: Dict[str, Union[int, str]] = {} 414 x["foo"] = "bar" 415 x["baz"] = 2 416 return x["baz"] 417 418 self.checkScript(fn, ()) 419 420 def test_union_module_with_union_instance_variable(self): 421 class M(torch.nn.Module): 422 x: Union[int, str] 423 424 def __init__(self, x: Union[int, str]): 425 super().__init__() 426 self.x: Union[int, str] = x 427 428 def forward(self, y: Union[int, str]): 429 self.x = y 430 return self.x 431 432 self.checkModule( 433 M( 434 2, 435 ), 436 (1,), 437 ) 438 self.checkModule(M("bar"), ("foo",)) 439 440 def test_union_module_with_union_class_variable(self): 441 class M(torch.nn.Module): 442 x: Union[int, str] = "foo" 443 444 def __init__(self, y: int): 445 super().__init__() 446 x = y 447 448 def forward(self, z: str): 449 x = z 450 return x 451 452 self.checkModule(M(1), ("foo",)) 453 454 def test_union_type_refinement(self): 455 def fn(x: Union[int, str]) -> str: 456 if isinstance(x, str): 457 z = x + "bar" 458 return x 459 else: 460 return "baz" 461 462 self.checkScript(fn, ("foo",)) 463 self.checkScript(fn, (1,)) 464 465 def test_union_type_refinement_union_rhs(self): 466 def fn(x: int) -> str: 467 if torch.jit.isinstance(x, Union[int, str]): 468 return "bar" 469 else: 470 return "baz" 471 472 self.checkScript(fn, (1,)) 473 474 def test_union_type_refinement_tuple_rhs(self): 475 def fn(x: Union[int, float, List[str]]) -> str: 476 if isinstance(x, (int, float)): 477 if isinstance(x, int): 478 return str(x) 479 else: 480 return "foo" 481 else: 482 if len(x): 483 return x[0] 484 else: 485 return "bar" 486 487 self.checkScript(fn, (1,)) 488 self.checkScript(fn, (1.0,)) 489 self.checkScript(fn, (["a", "b", "c"],)) 490 491 def test_union_type_refinement_tuple_rhs_noncontained_type(self): 492 def fn(x: Union[int, List[str]]) -> str: 493 if isinstance(x, (int, float)): 494 y = x + x 495 return str(y) 496 else: 497 if len(x): 498 return x[0] 499 else: 500 return "bar" 501 502 self.checkScript(fn, (1,)) 503 self.checkScript(fn, (["a", "b", "c"],)) 504 505 def test_union_type_refinement_tuple_rhs_union(self): 506 @torch.jit.script 507 def fn(x: int) -> str: 508 if torch.jit.isinstance(x, (Union[int, str], float)): 509 y = x + x 510 return str(y) 511 else: 512 return "foo" 513 514 # TODO: There's currently an unrelated bug in 515 # `torch.jit.isinstance` that makes it fail for tuple literals. 516 # Posted here: https://github.com/pytorch/pytorch/issues/60095 517 # Change `assertEqual` to `checkScript` when the bug is fixed 518 self.assertEqual(fn(1), "2") 519 520 def test_union_type_refinement_statically_false(self): 521 @torch.jit.script 522 def fn(x: int) -> str: 523 if torch.jit.isinstance(x, (Union[str, float], List[str], str)): 524 z = x + "foo" 525 return z 526 else: 527 return "bar" 528 529 s = fn.graph 530 531 # Check that we don't have any branching statements 532 FileCheck().check_not("block0()").check_not("block1()").run(s) 533 534 def test_union_type_refinement_statically_true(self): 535 @torch.jit.script 536 def fn(x: Union[List[int], int]) -> Union[List[int], int]: 537 if not torch.jit.isinstance(x, (int, List[int])): 538 return x 539 else: 540 l = [1, 2, 3] 541 y: Union[List[int], int] = l 542 return y 543 544 s = fn.graph 545 546 # Check that we don't have any branching statements 547 FileCheck().check_not("block0()").check_not("block1()").run(s) 548 549 def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): 550 def fn(x: Union[List[int], int]) -> int: 551 if torch.jit.isinstance(x, (int, float, str)): 552 # We should know that `x` is an `int` here 553 z = x + 1 554 return z 555 else: 556 return 100 557 558 self.checkScript(fn, ([1, 2, 3],)) 559 self.checkScript(fn, (1,)) 560 561 def test_union_type_refinement_partial_static_refinement_union_rhs(self): 562 def fn(x: Union[List[int], int]) -> int: 563 if torch.jit.isinstance(x, Union[int, float, str]): 564 # We should know that `x` is an `int` here 565 z = x + 1 566 return z 567 else: 568 return 100 569 570 self.checkScript(fn, ([1, 2, 3],)) 571 self.checkScript(fn, (1,)) 572 573 def test_union_type_refinement_internal_declaration(self): 574 def fn(flag: bool) -> str: 575 x: Union[int, str, None] = None 576 if flag: 577 y = "foo" 578 else: 579 y = 1 580 if isinstance(x, str): 581 return x 582 else: 583 return "bar" 584 585 self.checkScript(fn, (True,)) 586 self.checkScript(fn, (False,)) 587 588 def test_union_branching_with_union_return_and_homogenous_types(self): 589 def fn(x: int) -> Union[int, str]: 590 if x % 2: 591 return "foo" 592 else: 593 return "bar" 594 595 self.checkScript(fn, (1,)) 596 self.checkScript(fn, (8,)) 597 598 def test_union_branching_does_not_autoinfer_undeclared_union(self): 599 def fn(x: int) -> str: 600 if x % 2: 601 y = "foo" 602 else: 603 y = x 604 if isinstance(y, str): 605 return y 606 else: 607 return "bar" 608 609 with self.assertRaisesRegex( 610 RuntimeError, 611 "y is set to type str" 612 " in the true branch and type int " 613 "in the false branch", 614 ): 615 torch.jit.script(fn) 616 617 def test_union_branching_does_not_widen_existing_inferred_type(self): 618 def fn(x: int) -> str: 619 y = "foo" 620 if x % 2: 621 y = "bar" 622 else: 623 y = x 624 if isinstance(y, str): 625 return y 626 else: 627 return "baz" 628 629 with self.assertRaisesRegex( 630 RuntimeError, 631 "previously had type " 632 "str but is now being assigned to a" 633 " value of type int", 634 ): 635 torch.jit.script(fn) 636 637 def test_union_schema_matching_on_internal_type(self): 638 def fn(x: Union[List[int], Dict[str, int]]) -> int: 639 if torch.jit.isinstance(x, List[int]): 640 return x[0] 641 else: 642 return list(x.values())[0] 643 644 self.checkScript(fn, ([1, 2, 3],)) 645 self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) 646 647 def test_union_subtractive_refinement(self): 648 def fn(x: Union[List[int], int]) -> int: 649 if not isinstance(x, int): 650 x.append(1) 651 return x[0] 652 else: 653 return x 654 655 self.checkScript(fn, (1,)) 656 self.checkScript(fn, ([1, 2, 3],)) 657 658 def test_union_subtractive_refinement_with_container(self): 659 def fn(x: Union[List[int], int]) -> int: 660 if not torch.jit.isinstance(x, List[int]): 661 return x 662 else: 663 x.append(1) 664 return x[0] 665 666 self.checkScript(fn, (1,)) 667 self.checkScript(fn, ([1, 2, 3],)) 668 669 def test_union_memory_aliasing(self): 670 def fn(): 671 x: List[torch.Tensor] = [] 672 z: List[Optional[List[torch.Tensor]]] = [] 673 z.append(x) 674 x_alias = z[0] 675 if torch.jit.isinstance(x_alias, List[torch.Tensor]): 676 x_alias.append(torch.tensor(3)) 677 return x 678 679 self.checkScript(fn, ()) 680 681 def test_union_serialization_preserves_type_annotations(self): 682 # This function will fail after being torch.jit.save'd and 683 # torch.jit.load'd if the type annotations aren't preserved 684 # for Union during serialization. We need the `Union[str, int]` 685 # annotation to make sure that `y` is typed as a Union instead 686 # of as a str in one branch and an int in the other 687 def fn(x: int) -> str: 688 if x % 2: 689 y: Union[str, int] = "bar" 690 else: 691 y: Union[str, int] = x 692 if isinstance(y, str): 693 return y 694 else: 695 return "baz" 696 697 self.checkScript(fn, (1,)) 698 self.checkScript(fn, (8,)) 699 700 def _assert_passes(self, template: str, ann: str, lhs: str): 701 code = template.format(ann=ann, lhs=lhs) 702 self.checkScript(code, (), name="fn") 703 704 def _assert_raises(self, template: str, ann: str, lhs: str, msg: str): 705 code = template.format(ann=ann, lhs=lhs) 706 with self.assertRaisesRegex(RuntimeError, msg): 707 cu = torch.jit.CompilationUnit(code, _frames_up=1) 708 string_frontend = getattr(cu, "fn") # noqa: B009 709 710 def test_union_with_list_assignment(self): 711 template = dedent( 712 """ 713 def fn(): 714 x: {ann} = {lhs} 715 if torch.jit.isinstance(x, List[torch.Tensor]): 716 x.append(torch.tensor(3)) 717 return x 718 """ 719 ) 720 721 lhs = { 722 "list_literal_empty": "[]", 723 "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]", 724 "list_literal_of_str": '["foo", "bar", "baz"]', 725 "list_literal_of_mixed": "[torch.arange(5), 1]", 726 "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]", 727 "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]', 728 "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]", 729 } 730 731 """ 732 Union[List[str], List[torch.Tensor]] 733 """ 734 self._assert_raises( 735 template, 736 "Union[List[str], List[torch.Tensor]]", 737 lhs["list_literal_empty"], 738 "there are multiple possible List type " 739 "candidates in the Union annotation", 740 ) 741 742 self._assert_passes( 743 template, 744 "Union[List[str], List[torch.Tensor]]", 745 lhs["list_literal_of_tensor"], 746 ) 747 748 self._assert_passes( 749 template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"] 750 ) 751 752 self._assert_raises( 753 template, 754 "Union[List[str], List[torch.Tensor]]", 755 lhs["list_literal_of_mixed"], 756 "none of those types match the types of the" " given list elements", 757 ) 758 759 self._assert_passes( 760 template, 761 "Union[List[str], List[torch.Tensor]]", 762 lhs["list_comprehension_of_tensor"], 763 ) 764 765 self._assert_passes( 766 template, 767 "Union[List[str], List[torch.Tensor]]", 768 lhs["list_comprehension_of_str"], 769 ) 770 771 # TODO: Support mixed list comprehensions 772 self._assert_raises( 773 template, 774 "Union[List[str], List[torch.Tensor]]", 775 lhs["list_comprehension_of_mixed"], 776 "Arguments for call are not valid", 777 ) 778 779 """ 780 Union[int, torch.Tensor] 781 """ 782 self._assert_raises( 783 template, 784 "Union[int, torch.Tensor]", 785 lhs["list_literal_empty"], 786 "Expected an Union type annotation with an " "inner List type", 787 ) 788 789 self._assert_raises( 790 template, 791 "Union[int, torch.Tensor]", 792 lhs["list_literal_of_tensor"], 793 "Expected an Union type annotation with an " "inner List type", 794 ) 795 796 self._assert_raises( 797 template, 798 "Union[int, torch.Tensor]", 799 lhs["list_comprehension_of_tensor"], 800 "Expected an Union type annotation with an " "inner List type", 801 ) 802 803 """ 804 Union[List[torch.Tensor], int] 805 """ 806 self._assert_passes( 807 template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"] 808 ) 809 810 self._assert_passes( 811 template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"] 812 ) 813 814 self._assert_raises( 815 template, 816 "Union[List[torch.Tensor], int]", 817 lhs["list_literal_of_str"], 818 r"List type annotation `List\[Tensor\]` did " 819 "not match the types of the given list " 820 "elements", 821 ) 822 823 self._assert_raises( 824 template, 825 "Union[List[torch.Tensor], int]", 826 lhs["list_literal_of_mixed"], 827 r"List type annotation `List\[Tensor\]` did " 828 "not match the types of the given list " 829 "elements", 830 ) 831 832 self._assert_passes( 833 template, 834 "Union[List[torch.Tensor], int]", 835 lhs["list_comprehension_of_tensor"], 836 ) 837 838 self._assert_raises( 839 template, 840 "Union[List[torch.Tensor], int]", 841 lhs["list_comprehension_of_str"], 842 r"List type annotation `List\[Tensor\]` did " 843 "not match the types of the given list " 844 "elements", 845 ) 846 847 # TODO(@ansley): Support mixed list comprehensions 848 self._assert_raises( 849 template, 850 "Union[List[torch.Tensor], int]", 851 lhs["list_comprehension_of_mixed"], 852 "Arguments for call are not valid", 853 ) 854 855 def test_union_with_dict_assignment(self): 856 template = dedent( 857 """ 858 def fn(): 859 x: {ann} = {lhs} 860 if torch.jit.isinstance(x, Dict[str, torch.Tensor]): 861 x["foo"] = torch.tensor(3) 862 return x 863 """ 864 ) 865 866 lhs = { 867 "dict_literal_empty": "{}", 868 "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}', 869 "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}', 870 "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}', 871 "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \ 872 zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}', 873 "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \ 874 zip(["foo", "bar"], [1, 2]}', 875 "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \ 876 zip(["foo", "bar"], [torch.arange(3), 2])}', 877 "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))", 878 "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])', 879 "dict_keyword_with_empty_iterable": "dict([])", 880 "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])', 881 "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})', 882 "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))', 883 } 884 885 """ 886 Union[Dict[str, torch.Tensor], Dict[str, int]] 887 """ 888 self._assert_raises( 889 template, 890 "Union[List[str], List[torch.Tensor]]", 891 lhs["dict_literal_empty"], 892 "Expected an Union type annotation with an " "inner Dict type", 893 ) 894 895 self._assert_passes( 896 template, 897 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 898 lhs["dict_literal_of_str_tensor"], 899 ) 900 901 self._assert_passes( 902 template, 903 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 904 lhs["dict_literal_of_str_int"], 905 ) 906 907 self._assert_raises( 908 template, 909 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 910 lhs["dict_literal_of_mixed"], 911 "none of those dict types can hold the " 912 "types of the given keys and values", 913 ) 914 915 # TODO: String frontend does not support tuple unpacking 916 # https://github.com/pytorch/pytorch/issues/64096 917 # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 918 # lhs["dict_comprehension_of_str_tensor"]) 919 920 # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 921 # lhs["dict_comprehension_of_str_int"]) 922 923 # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 924 # lhs["dict_comprehension_of_mixed"], 925 # "foobar") 926 927 # self._assert_passes(template, 928 # "Union[Dict[str, torch.Tensor], Dict[str, int]]", 929 # lhs["dict_keyword_with_internal_aggregate_function"]) 930 931 # TODO(@ansley): Follow-up project needed for full type 932 # inference with dict keyword (supported for dict comprehension 933 # and dict literal already; should not be a blocker for anyone) 934 self._assert_raises( 935 template, 936 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 937 lhs["dict_keyword"], 938 "full type inference is not yet supported", 939 ) 940 941 self._assert_raises( 942 template, 943 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 944 lhs["dict_keyword_with_iterable"], 945 "full type inference is not yet supported", 946 ) 947 948 self._assert_raises( 949 template, 950 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 951 lhs["dict_keyword_with_empty_iterable"], 952 "full type inference is not yet supported", 953 ) 954 955 self._assert_raises( 956 template, 957 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 958 lhs["dict_keyword_with_mapping"], 959 "full type inference is not yet supported", 960 ) 961 962 self._assert_raises( 963 template, 964 "Union[Dict[str, torch.Tensor], Dict[str, int]]", 965 lhs["dict_keyword_with_mapping_and_kwargs"], 966 "full type inference is not yet supported", 967 ) 968 969 """ 970 Union[int, torch.Tensor] 971 """ 972 self._assert_raises( 973 template, 974 "Union[int, torch.Tensor]", 975 lhs["dict_literal_empty"], 976 "Expected an Union type annotation with " "an inner Dict type", 977 ) 978 979 self._assert_raises( 980 template, 981 "Union[int, torch.Tensor]", 982 lhs["dict_literal_of_str_tensor"], 983 "Expected an Union type annotation with " "an inner Dict type", 984 ) 985 986 # See above--string frontend does not support tuple unpacking 987 # self._assert_raises(template, "Union[int, torch.Tensor]", 988 # lhs["dict_comprehension_of_tensor"], 989 # "foobar") 990 991 """ 992 Union[Dict[str, torch.Tensor], int] 993 """ 994 self._assert_passes( 995 template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"] 996 ) 997 998 self._assert_passes( 999 template, 1000 "Union[Dict[str, torch.Tensor], int]", 1001 lhs["dict_literal_of_str_tensor"], 1002 ) 1003 1004 self._assert_raises( 1005 template, 1006 "Union[Dict[str, torch.Tensor], int]", 1007 lhs["dict_literal_of_str_int"], 1008 "Type annotation was inferred to be " 1009 r"`Dict\[str, Tensor\]`, but the type of " 1010 "values given by the dict literal is", 1011 ) 1012 1013 self._assert_raises( 1014 template, 1015 "Union[Dict[str, torch.Tensor], int]", 1016 lhs["dict_literal_of_mixed"], 1017 "Type annotation was inferred to be " 1018 r"`Dict\[str, Tensor\]`, but the type of " 1019 "values given by the dict literal is", 1020 ) 1021 1022 self._assert_passes( 1023 template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"] 1024 ) 1025 1026 self._assert_passes( 1027 template, 1028 "Union[Dict[str, torch.Tensor], int]", 1029 lhs["dict_keyword_with_iterable"], 1030 ) 1031 1032 self._assert_passes( 1033 template, 1034 "Union[Dict[str, torch.Tensor], int]", 1035 lhs["dict_keyword_with_empty_iterable"], 1036 ) 1037 1038 self._assert_passes( 1039 template, 1040 "Union[Dict[str, torch.Tensor], int]", 1041 lhs["dict_keyword_with_mapping"], 1042 ) 1043 1044 self._assert_passes( 1045 template, 1046 "Union[Dict[str, torch.Tensor], int]", 1047 lhs["dict_keyword_with_mapping_and_kwargs"], 1048 ) 1049 1050 # See above--string frontend does not support tuple unpacking 1051 # self._assert_passes(template, 1052 # "Union[Dict[str, torch.Tensor], int]", 1053 # lhs["dict_keyword_with_internal_aggregate_function"]) 1054 # 1055 # self._assert_passes(template, 1056 # "Union[Dict[str, torch.Tensor], int]", 1057 # lhs["dict_comprehension_of_str_tensor"]) 1058 1059 # self._assert_raises(template, 1060 # "Union[Dict[str, torch.Tensor], int]", 1061 # lhs["dict_comprehension_of_str_int"], 1062 # "foobar") 1063 1064 # self._assert_raises(template, 1065 # "Union[Dict[str, torch.Tensor], int]", 1066 # lhs["dict_comprehension_of_mixed"], 1067 # "foobar") 1068