xref: /aosp_15_r20/external/pytorch/test/jit/test_union.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6from enum import Enum
7from textwrap import dedent
8from typing import Dict, List, Optional, Tuple, Union
9
10import torch
11from torch.testing import FileCheck
12
13
# Make the helper files in test/ importable. This file lives in test/jit/,
# so taking the parent directory twice from its resolved path yields test/.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
17from torch.testing._internal.jit_utils import JitTestCase, make_global
18
19
# This module only defines test cases; they are meant to be run through
# test/test_jit.py, so executing the file directly fails with a usage hint.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
26
27
28class TestUnion(JitTestCase):
29    """
30    This class tests the functionality of `Union`.
31
32    Note: It's important to be able to refine the type of a `Union` to
33    one of its internal types. Currently, there are differences in the
34    way Python expects `isinstance` checks and the way TorchScript
35    expects `isinstance` checks. This means that we can't use
36    `checkScript` in our test cases because either the eager mode or the
37    script mode wouldn't run! So, some test cases have separate but
38    equivalent functions to emulate `checkScript`.
39    """
40
41    def test_check_union_annotation(self):
42        def test_func(a: Union[int, float], b: Optional[int]):
43            return 0
44
45        scripted_func = torch.jit.script(test_func)
46        graph_rep = str(scripted_func.graph)
47        code_rep = str(scripted_func.code)
48        # TS graph IR for Union should be annotated as Union()
49        FileCheck().check("Union(").check("int?").run(graph_rep)
50        # Serialized code for Union should be annotated as Union[]
51        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
52        self.checkScript(test_func, (5, 6))
53        # this shouldn't error out
54        torch._C.parse_ir(str(scripted_func.graph))
55
56    def test_union_with_scalar_values(self):
57        def fn(x: Union[int, float]) -> str:
58            return "foo"
59
60        self.checkScript(fn, (1,))
61        self.checkScript(fn, (1.0,))
62
63        scripted = torch.jit.script(fn)
64
65        with self.assertRaisesRegex(
66            RuntimeError,
67            "Expected a member of"
68            r" Union\[float, int\] but "
69            "instead found type str",
70        ):
71            scripted("1")
72
73    def test_union_with_collections(self):
74        def fn(x: Union[Dict[str, int], List[int]]) -> str:
75            return "foo"
76
77        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
78        self.checkScript(fn, ([1, 2, 3],))
79
80        scripted = torch.jit.script(fn)
81
82        with self.assertRaisesRegex(
83            RuntimeError,
84            "Expected a member of"
85            r" Union\[List\[int\], Dict\[str, "
86            r"int\]\] but instead found type "
87            r"Dict\[str, str\]",
88        ):
89            scripted({"foo": "bar", "baz": "qux"})
90
91        with self.assertRaisesRegex(
92            RuntimeError,
93            "Expected a member of"
94            r" Union\[List\[int\], Dict\[str, "
95            r"int\]\] but instead found type "
96            r"List\[str\]",
97        ):
98            scripted(["foo", "bar", "baz"])
99
100        with self.assertRaisesRegex(
101            RuntimeError,
102            "Expected a member of"
103            r" Union\[List\[int\], Dict\[str, "
104            r"int\]\] but instead found type "
105            "str",
106        ):
107            scripted("1")
108
109    def test_union_with_enum(self):
110        class Color(Enum):
111            RED = 1
112            GREEN = 2
113
114        make_global(Color)
115
116        def fn(x: Union[str, Color]) -> str:
117            return "foo"
118
119        self.checkScript(fn, (Color.RED,))
120        self.checkScript(fn, ("red",))
121
122        scripted = torch.jit.script(fn)
123
124        with self.assertRaisesRegex(
125            RuntimeError,
126            "Expected a member of"
127            r" Union\[__torch__.jit.test_union."
128            r"Color, str\] but instead found "
129            "type int",
130        ):
131            scripted(1)
132
133    def test_union_in_class_constructor(self):
134        @torch.jit.script  # noqa: B903
135        class A:  # noqa: B903
136            def __init__(self, x: Union[int, str]) -> None:
137                self.x = x
138
139        def fn(x: Union[str, int]) -> A:
140            return A(x)
141
142        self.assertEqual(fn("foo").x, "foo")
143        self.assertEqual(fn(1).x, 1)
144
145        scripted = torch.jit.script(fn)
146
147        with self.assertRaisesRegex(
148            RuntimeError,
149            "Expected a member of"
150            r" Union\[int, str\] but instead "
151            r"found type List\[str\]",
152        ):
153            scripted(["foo", "bar", "baz"])
154
    def test_union_return_type(self):
        """A function annotated to return Union[int, str] may return a str."""

        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))
160
    def test_union_as_annotation(self):
        """A local variable can be annotated with a Union type and returned."""

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
167
    def test_union_as_annotation_in_typed_container(self):
        """Union-typed values of either member can be appended to List[Union]."""

        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())
177
    def test_union_as_annotation_py2(self):
        """Union is accepted in a Python 2 style `# type:` signature comment."""

        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
185
    def test_union_as_internal_tuple_type(self):
        """Union types can be used as element types of a Tuple annotation."""

        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())
192
    def test_union_variable_can_be_reassigned(self):
        """A Union-typed variable can be rebound to values of each member type.

        `x` starts as a str, is rebound to an int and passed to `aux1`, which
        takes an `int`, and is then rebound to a str before being returned.
        """

        @torch.jit.script
        def aux1(i: int):
            return int(i**2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())
212
213    def test_union_does_not_replace_existing_annotated_type(self):
214        def fn():
215            x: List[int] = [1, 2, 3]
216            x.append("foo")
217            return x
218
219        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
220            scripted = torch.jit.script(fn)
221            scripted()
222
223    def test_union_does_not_replace_existing_annotated_type_union(self):
224        def fn():
225            x: List[Union[int, str]] = [1, "foo", 3]
226            x.append(2.0)
227            return x
228
229        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
230            scripted = torch.jit.script(fn)
231            scripted()
232
233    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
234        def fn():
235            x: List[int] = []
236            x.append("foo")
237            return x
238
239        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
240            scripted = torch.jit.script(fn)
241            scripted()
242
243    def test_unions_of_unions_are_flattened(self):
244        @torch.jit.script
245        def fn(x: Union[Union[int, str], float]) -> str:
246            return "foo"
247
248        s = fn.graph
249
250        FileCheck().check("x : Union(float, int, str)").run(s)
251
252    def test_unions_of_a_single_argument_vanish(self):
253        @torch.jit.script
254        def fn(x: Union[int]) -> str:
255            return "foo"
256
257        s = fn.graph
258
259        FileCheck().check("x : int").run(s)
260
261    def test_union_redundant_arguments_are_skipped(self):
262        @torch.jit.script
263        def fn(x: Union[int, str, int]) -> str:
264            return "foo"
265
266        s = fn.graph
267
268        FileCheck().check("x : Union(int, str)").run(s)
269
270    def test_union_redundant_arguments_are_skipped_optional(self):
271        @torch.jit.script
272        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
273            return "foo"
274
275        s = fn.graph
276
277        FileCheck().check("x : Union(float, int, NoneType)").run(s)
278
279    def test_union_redundant_arguments_are_skipped_subtyping(self):
280        @torch.jit.script
281        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
282            return "foo"
283
284        s = fn.graph
285
286        FileCheck().check("x : Union((int?, int), str)").run(s)
287
288    def test_union_redundant_arguments_are_skipped_container(self):
289        @torch.jit.script
290        def fn(x: Union[List[str], List[float], List[str]]) -> str:
291            return "foo"
292
293        s = fn.graph
294
295        FileCheck().check("x : Union(float[], str[])").run(s)
296
297    def test_union_argument_order_is_ignored(self):
298        @torch.jit.script
299        def fn1(x: Union[int, str]) -> str:
300            return "foo"
301
302        @torch.jit.script
303        def fn2(x: Union[str, int]) -> str:
304            return "foo"
305
306        for s in (fn1.graph, fn2.graph):
307            FileCheck().check("x : Union(int, str)").run(s)
308
309    def test_union_argument_order_is_ignored_container(self):
310        @torch.jit.script
311        def fn1(x: Union[List[str], List[int]]) -> str:
312            return "foo"
313
314        @torch.jit.script
315        def fn2(x: Union[List[int], List[str]]) -> str:
316            return "foo"
317
318        for s in (fn1.graph, fn2.graph):
319            FileCheck().check("x : Union(int[], str[])").run(s)
320
321    def test_union_T_None_is_equivalent_to_optional_T(self):
322        @torch.jit.script
323        def inner(x: Union[int, None]) -> int:
324            if x is not None:
325                return x
326            else:
327                return 5
328
329        @torch.jit.script
330        def fn1() -> int:
331            a: Optional[int] = 5
332            b: Optional[int] = None
333            a_ = inner(a)
334            b_ = inner(b)
335            return a_ + b_
336
337        self.assertEqual(fn1(), 10)
338
339        @torch.jit.script
340        def inner2(x: Optional[int]) -> int:
341            if x is not None:
342                return x
343            else:
344                return 5
345
346        @torch.jit.script
347        def fn2() -> int:
348            a: Union[int, None] = 5
349            b: Union[int, None] = None
350            a_ = inner(a)
351            b_ = inner(b)
352            return a_ + b_
353
354        self.assertEqual(fn2(), 10)
355
356    def test_union_optional_of_union_is_flattened(self):
357        @torch.jit.script
358        def fn(flag: int) -> Union[str, int, None]:
359            y: Union[int, str, None] = "foo"
360            if flag == 0:
361                x: Optional[Union[int, str]] = y
362            elif flag == 1:
363                x: Optional[Union[int, str]] = 1
364            else:
365                x: Optional[Union[int, str]] = None
366            return x
367
368        # Can't use `checkScript` because it will flag the fact that
369        # the original code has `Optional[Union[int, str]]` but the
370        # saved/loaded code has `Union[int, NoneType, str]` (even
371        # though this is exactly what we want)
372        self.assertEqual(fn(0), "foo")
373        self.assertEqual(fn(1), 1)
374        self.assertEqual(fn(2), None)
375
376        buffer = io.BytesIO()
377        torch.jit.save(fn, buffer)
378        buffer = io.BytesIO(buffer.getvalue())
379        l = torch.jit.load(buffer)
380
381        s = l.code
382
383        FileCheck().check("Union[int, NoneType, str]").check(
384            "Union[int, NoneType, str]"
385        ).run(s)
386
    def test_union_subclasses_larger_union(self):
        """A Union[int, str] value is returned where Union[int, str, Tensor] is declared."""

        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
393
    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        """Scripting a dict with Union-typed keys currently raises an error."""

        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(
            RuntimeError,
            "only int, float, "
            "complex, Tensor, device and string keys "
            "are supported",
        ):
            torch.jit.script(fn)
410
    def test_union_as_dict_value(self):
        """Dict values may be Union-typed and hold values of different members."""

        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())
419
420    def test_union_module_with_union_instance_variable(self):
421        class M(torch.nn.Module):
422            x: Union[int, str]
423
424            def __init__(self, x: Union[int, str]):
425                super().__init__()
426                self.x: Union[int, str] = x
427
428            def forward(self, y: Union[int, str]):
429                self.x = y
430                return self.x
431
432        self.checkModule(
433            M(
434                2,
435            ),
436            (1,),
437        )
438        self.checkModule(M("bar"), ("foo",))
439
    def test_union_module_with_union_class_variable(self):
        """checkModule accepts a module declaring a Union class attribute with a default.

        NOTE(review): `x = y` in `__init__` and `x = z` in `forward` bind local
        variables, not `self.x`, so the Union-typed class attribute is never
        read or written by this test. Confirm whether `self.x` was intended.
        """

        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                x = y

            def forward(self, z: str):
                x = z
                return x

        self.checkModule(M(1), ("foo",))
453
454    def test_union_type_refinement(self):
455        def fn(x: Union[int, str]) -> str:
456            if isinstance(x, str):
457                z = x + "bar"
458                return x
459            else:
460                return "baz"
461
462        self.checkScript(fn, ("foo",))
463        self.checkScript(fn, (1,))
464
    def test_union_type_refinement_union_rhs(self):
        """`torch.jit.isinstance` accepts a Union type as its second argument."""

        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))
473
474    def test_union_type_refinement_tuple_rhs(self):
475        def fn(x: Union[int, float, List[str]]) -> str:
476            if isinstance(x, (int, float)):
477                if isinstance(x, int):
478                    return str(x)
479                else:
480                    return "foo"
481            else:
482                if len(x):
483                    return x[0]
484                else:
485                    return "bar"
486
487        self.checkScript(fn, (1,))
488        self.checkScript(fn, (1.0,))
489        self.checkScript(fn, (["a", "b", "c"],))
490
491    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
492        def fn(x: Union[int, List[str]]) -> str:
493            if isinstance(x, (int, float)):
494                y = x + x
495                return str(y)
496            else:
497                if len(x):
498                    return x[0]
499                else:
500                    return "bar"
501
502        self.checkScript(fn, (1,))
503        self.checkScript(fn, (["a", "b", "c"],))
504
    def test_union_type_refinement_tuple_rhs_union(self):
        """`torch.jit.isinstance` accepts a tuple that itself contains a Union."""

        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")
519
520    def test_union_type_refinement_statically_false(self):
521        @torch.jit.script
522        def fn(x: int) -> str:
523            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
524                z = x + "foo"
525                return z
526            else:
527                return "bar"
528
529        s = fn.graph
530
531        # Check that we don't have any branching statements
532        FileCheck().check_not("block0()").check_not("block1()").run(s)
533
534    def test_union_type_refinement_statically_true(self):
535        @torch.jit.script
536        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
537            if not torch.jit.isinstance(x, (int, List[int])):
538                return x
539            else:
540                l = [1, 2, 3]
541                y: Union[List[int], int] = l
542                return y
543
544        s = fn.graph
545
546        # Check that we don't have any branching statements
547        FileCheck().check_not("block0()").check_not("block1()").run(s)
548
549    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
550        def fn(x: Union[List[int], int]) -> int:
551            if torch.jit.isinstance(x, (int, float, str)):
552                # We should know that `x` is an `int` here
553                z = x + 1
554                return z
555            else:
556                return 100
557
558        self.checkScript(fn, ([1, 2, 3],))
559        self.checkScript(fn, (1,))
560
561    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
562        def fn(x: Union[List[int], int]) -> int:
563            if torch.jit.isinstance(x, Union[int, float, str]):
564                # We should know that `x` is an `int` here
565                z = x + 1
566                return z
567            else:
568                return 100
569
570        self.checkScript(fn, ([1, 2, 3],))
571        self.checkScript(fn, (1,))
572
573    def test_union_type_refinement_internal_declaration(self):
574        def fn(flag: bool) -> str:
575            x: Union[int, str, None] = None
576            if flag:
577                y = "foo"
578            else:
579                y = 1
580            if isinstance(x, str):
581                return x
582            else:
583                return "bar"
584
585        self.checkScript(fn, (True,))
586        self.checkScript(fn, (False,))
587
588    def test_union_branching_with_union_return_and_homogenous_types(self):
589        def fn(x: int) -> Union[int, str]:
590            if x % 2:
591                return "foo"
592            else:
593                return "bar"
594
595        self.checkScript(fn, (1,))
596        self.checkScript(fn, (8,))
597
    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        """Scripting rejects a variable given different types in two branches.

        `y` is a str in one branch and an int in the other and is used
        afterwards. No Union is inferred, so `torch.jit.script` raises.
        """

        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        with self.assertRaisesRegex(
            RuntimeError,
            "y is set to type str"
            " in the true branch and type int "
            "in the false branch",
        ):
            torch.jit.script(fn)
616
    def test_union_branching_does_not_widen_existing_inferred_type(self):
        """A variable first inferred as str cannot later be assigned an int.

        Scripting raises rather than widening `y` to a Union.
        """

        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(
            RuntimeError,
            "previously had type "
            "str but is now being assigned to a"
            " value of type int",
        ):
            torch.jit.script(fn)
636
637    def test_union_schema_matching_on_internal_type(self):
638        def fn(x: Union[List[int], Dict[str, int]]) -> int:
639            if torch.jit.isinstance(x, List[int]):
640                return x[0]
641            else:
642                return list(x.values())[0]
643
644        self.checkScript(fn, ([1, 2, 3],))
645        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
646
647    def test_union_subtractive_refinement(self):
648        def fn(x: Union[List[int], int]) -> int:
649            if not isinstance(x, int):
650                x.append(1)
651                return x[0]
652            else:
653                return x
654
655        self.checkScript(fn, (1,))
656        self.checkScript(fn, ([1, 2, 3],))
657
658    def test_union_subtractive_refinement_with_container(self):
659        def fn(x: Union[List[int], int]) -> int:
660            if not torch.jit.isinstance(x, List[int]):
661                return x
662            else:
663                x.append(1)
664                return x[0]
665
666        self.checkScript(fn, (1,))
667        self.checkScript(fn, ([1, 2, 3],))
668
    def test_union_memory_aliasing(self):
        """Mutation through a refined alias pulled from a container reaches the original.

        `x` is stored in `z`, read back as `x_alias`, refined with
        `torch.jit.isinstance`, and appended to. `fn` returns the original `x`.
        """

        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())
680
681    def test_union_serialization_preserves_type_annotations(self):
682        # This function will fail after being torch.jit.save'd and
683        # torch.jit.load'd if the type annotations aren't preserved
684        # for Union during serialization. We need the `Union[str, int]`
685        # annotation to make sure that `y` is typed as a Union instead
686        # of as a str in one branch and an int in the other
687        def fn(x: int) -> str:
688            if x % 2:
689                y: Union[str, int] = "bar"
690            else:
691                y: Union[str, int] = x
692            if isinstance(y, str):
693                return y
694            else:
695                return "baz"
696
697        self.checkScript(fn, (1,))
698        self.checkScript(fn, (8,))
699
700    def _assert_passes(self, template: str, ann: str, lhs: str):
701        code = template.format(ann=ann, lhs=lhs)
702        self.checkScript(code, (), name="fn")
703
704    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
705        code = template.format(ann=ann, lhs=lhs)
706        with self.assertRaisesRegex(RuntimeError, msg):
707            cu = torch.jit.CompilationUnit(code, _frames_up=1)
708            string_frontend = getattr(cu, "fn")  # noqa: B009
709
710    def test_union_with_list_assignment(self):
711        template = dedent(
712            """
713            def fn():
714                x: {ann} = {lhs}
715                if torch.jit.isinstance(x, List[torch.Tensor]):
716                    x.append(torch.tensor(3))
717                return x
718        """
719        )
720
721        lhs = {
722            "list_literal_empty": "[]",
723            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
724            "list_literal_of_str": '["foo", "bar", "baz"]',
725            "list_literal_of_mixed": "[torch.arange(5), 1]",
726            "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
727            "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
728            "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
729        }
730
731        """
732        Union[List[str], List[torch.Tensor]]
733        """
734        self._assert_raises(
735            template,
736            "Union[List[str], List[torch.Tensor]]",
737            lhs["list_literal_empty"],
738            "there are multiple possible List type "
739            "candidates in the Union annotation",
740        )
741
742        self._assert_passes(
743            template,
744            "Union[List[str], List[torch.Tensor]]",
745            lhs["list_literal_of_tensor"],
746        )
747
748        self._assert_passes(
749            template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
750        )
751
752        self._assert_raises(
753            template,
754            "Union[List[str], List[torch.Tensor]]",
755            lhs["list_literal_of_mixed"],
756            "none of those types match the types of the" " given list elements",
757        )
758
759        self._assert_passes(
760            template,
761            "Union[List[str], List[torch.Tensor]]",
762            lhs["list_comprehension_of_tensor"],
763        )
764
765        self._assert_passes(
766            template,
767            "Union[List[str], List[torch.Tensor]]",
768            lhs["list_comprehension_of_str"],
769        )
770
771        # TODO: Support mixed list comprehensions
772        self._assert_raises(
773            template,
774            "Union[List[str], List[torch.Tensor]]",
775            lhs["list_comprehension_of_mixed"],
776            "Arguments for call are not valid",
777        )
778
779        """
780        Union[int, torch.Tensor]
781        """
782        self._assert_raises(
783            template,
784            "Union[int, torch.Tensor]",
785            lhs["list_literal_empty"],
786            "Expected an Union type annotation with an " "inner List type",
787        )
788
789        self._assert_raises(
790            template,
791            "Union[int, torch.Tensor]",
792            lhs["list_literal_of_tensor"],
793            "Expected an Union type annotation with an " "inner List type",
794        )
795
796        self._assert_raises(
797            template,
798            "Union[int, torch.Tensor]",
799            lhs["list_comprehension_of_tensor"],
800            "Expected an Union type annotation with an " "inner List type",
801        )
802
803        """
804        Union[List[torch.Tensor], int]
805        """
806        self._assert_passes(
807            template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
808        )
809
810        self._assert_passes(
811            template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
812        )
813
814        self._assert_raises(
815            template,
816            "Union[List[torch.Tensor], int]",
817            lhs["list_literal_of_str"],
818            r"List type annotation `List\[Tensor\]` did "
819            "not match the types of the given list "
820            "elements",
821        )
822
823        self._assert_raises(
824            template,
825            "Union[List[torch.Tensor], int]",
826            lhs["list_literal_of_mixed"],
827            r"List type annotation `List\[Tensor\]` did "
828            "not match the types of the given list "
829            "elements",
830        )
831
832        self._assert_passes(
833            template,
834            "Union[List[torch.Tensor], int]",
835            lhs["list_comprehension_of_tensor"],
836        )
837
838        self._assert_raises(
839            template,
840            "Union[List[torch.Tensor], int]",
841            lhs["list_comprehension_of_str"],
842            r"List type annotation `List\[Tensor\]` did "
843            "not match the types of the given list "
844            "elements",
845        )
846
847        # TODO(@ansley): Support mixed list comprehensions
848        self._assert_raises(
849            template,
850            "Union[List[torch.Tensor], int]",
851            lhs["list_comprehension_of_mixed"],
852            "Arguments for call are not valid",
853        )
854
855    def test_union_with_dict_assignment(self):
856        template = dedent(
857            """
858            def fn():
859                x: {ann} = {lhs}
860                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
861                    x["foo"] = torch.tensor(3)
862                return x
863        """
864        )
865
866        lhs = {
867            "dict_literal_empty": "{}",
868            "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
869            "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
870            "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
871            "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
872                    zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
873            "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
874                    zip(["foo", "bar"], [1, 2]}',
875            "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
876                    zip(["foo", "bar"], [torch.arange(3), 2])}',
877            "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
878            "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
879            "dict_keyword_with_empty_iterable": "dict([])",
880            "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
881            "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
882            "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
883        }
884
885        """
886        Union[Dict[str, torch.Tensor], Dict[str, int]]
887        """
888        self._assert_raises(
889            template,
890            "Union[List[str], List[torch.Tensor]]",
891            lhs["dict_literal_empty"],
892            "Expected an Union type annotation with an " "inner Dict type",
893        )
894
895        self._assert_passes(
896            template,
897            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
898            lhs["dict_literal_of_str_tensor"],
899        )
900
901        self._assert_passes(
902            template,
903            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
904            lhs["dict_literal_of_str_int"],
905        )
906
907        self._assert_raises(
908            template,
909            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
910            lhs["dict_literal_of_mixed"],
911            "none of those dict types can hold the "
912            "types of the given keys and values",
913        )
914
915        # TODO: String frontend does not support tuple unpacking
916        # https://github.com/pytorch/pytorch/issues/64096
917        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
918        #              lhs["dict_comprehension_of_str_tensor"])
919
920        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
921        #              lhs["dict_comprehension_of_str_int"])
922
923        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
924        #              lhs["dict_comprehension_of_mixed"],
925        #              "foobar")
926
927        # self._assert_passes(template,
928        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
929        #                    lhs["dict_keyword_with_internal_aggregate_function"])
930
931        # TODO(@ansley): Follow-up project needed for full type
932        # inference with dict keyword (supported for dict comprehension
933        # and dict literal already; should not be a blocker for anyone)
934        self._assert_raises(
935            template,
936            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
937            lhs["dict_keyword"],
938            "full type inference is not yet supported",
939        )
940
941        self._assert_raises(
942            template,
943            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
944            lhs["dict_keyword_with_iterable"],
945            "full type inference is not yet supported",
946        )
947
948        self._assert_raises(
949            template,
950            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
951            lhs["dict_keyword_with_empty_iterable"],
952            "full type inference is not yet supported",
953        )
954
955        self._assert_raises(
956            template,
957            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
958            lhs["dict_keyword_with_mapping"],
959            "full type inference is not yet supported",
960        )
961
962        self._assert_raises(
963            template,
964            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
965            lhs["dict_keyword_with_mapping_and_kwargs"],
966            "full type inference is not yet supported",
967        )
968
969        """
970        Union[int, torch.Tensor]
971        """
972        self._assert_raises(
973            template,
974            "Union[int, torch.Tensor]",
975            lhs["dict_literal_empty"],
976            "Expected an Union type annotation with " "an inner Dict type",
977        )
978
979        self._assert_raises(
980            template,
981            "Union[int, torch.Tensor]",
982            lhs["dict_literal_of_str_tensor"],
983            "Expected an Union type annotation with " "an inner Dict type",
984        )
985
986        # See above--string frontend does not support tuple unpacking
987        # self._assert_raises(template, "Union[int, torch.Tensor]",
988        #              lhs["dict_comprehension_of_tensor"],
989        #              "foobar")
990
991        """
992        Union[Dict[str, torch.Tensor], int]
993        """
994        self._assert_passes(
995            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
996        )
997
998        self._assert_passes(
999            template,
1000            "Union[Dict[str, torch.Tensor], int]",
1001            lhs["dict_literal_of_str_tensor"],
1002        )
1003
1004        self._assert_raises(
1005            template,
1006            "Union[Dict[str, torch.Tensor], int]",
1007            lhs["dict_literal_of_str_int"],
1008            "Type annotation was inferred to be "
1009            r"`Dict\[str, Tensor\]`, but the type of "
1010            "values given by the dict literal is",
1011        )
1012
1013        self._assert_raises(
1014            template,
1015            "Union[Dict[str, torch.Tensor], int]",
1016            lhs["dict_literal_of_mixed"],
1017            "Type annotation was inferred to be "
1018            r"`Dict\[str, Tensor\]`, but the type of "
1019            "values given by the dict literal is",
1020        )
1021
1022        self._assert_passes(
1023            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
1024        )
1025
1026        self._assert_passes(
1027            template,
1028            "Union[Dict[str, torch.Tensor], int]",
1029            lhs["dict_keyword_with_iterable"],
1030        )
1031
1032        self._assert_passes(
1033            template,
1034            "Union[Dict[str, torch.Tensor], int]",
1035            lhs["dict_keyword_with_empty_iterable"],
1036        )
1037
1038        self._assert_passes(
1039            template,
1040            "Union[Dict[str, torch.Tensor], int]",
1041            lhs["dict_keyword_with_mapping"],
1042        )
1043
1044        self._assert_passes(
1045            template,
1046            "Union[Dict[str, torch.Tensor], int]",
1047            lhs["dict_keyword_with_mapping_and_kwargs"],
1048        )
1049
1050        # See above--string frontend does not support tuple unpacking
1051        # self._assert_passes(template,
1052        #                    "Union[Dict[str, torch.Tensor], int]",
1053        #                    lhs["dict_keyword_with_internal_aggregate_function"])
1054        #
1055        # self._assert_passes(template,
1056        #                    "Union[Dict[str, torch.Tensor], int]",
1057        #                    lhs["dict_comprehension_of_str_tensor"])
1058
1059        # self._assert_raises(template,
1060        #                    "Union[Dict[str, torch.Tensor], int]",
1061        #                    lhs["dict_comprehension_of_str_int"],
1062        #                    "foobar")
1063
1064        # self._assert_raises(template,
1065        #                    "Union[Dict[str, torch.Tensor], int]",
1066        #                    lhs["dict_comprehension_of_mixed"],
1067        #                    "foobar")
1068