xref: /aosp_15_r20/external/pytorch/test/jit/test_union.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerfrom enum import Enum
7*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
8*da0073e9SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple, Union
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport torch
11*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
# Make the helper files in test/ importable: resolve this file's real path,
# go up two directory levels, and append that directory to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, make_global
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
# Refuse direct execution: when this file is run as a script, raise with a
# message that points at test/test_jit.py instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Workerclass TestUnion(JitTestCase):
29*da0073e9SAndroid Build Coastguard Worker    """
30*da0073e9SAndroid Build Coastguard Worker    This class tests the functionality of `Union`.
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker    Note: It's important to be able to refine the type of a `Union` to
33*da0073e9SAndroid Build Coastguard Worker    one of its internal types. Currently, there are differences in the
34*da0073e9SAndroid Build Coastguard Worker    way Python expects `isinstance` checks and the way TorchScript
35*da0073e9SAndroid Build Coastguard Worker    expects `isinstance` checks. This means that we can't use
36*da0073e9SAndroid Build Coastguard Worker    `checkScript` in our test cases because either the eager mode or the
37*da0073e9SAndroid Build Coastguard Worker    script mode wouldn't run! So, some test cases have separate but
38*da0073e9SAndroid Build Coastguard Worker    equivalent functions to emulate `checkScript`.
39*da0073e9SAndroid Build Coastguard Worker    """
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker    def test_check_union_annotation(self):
42*da0073e9SAndroid Build Coastguard Worker        def test_func(a: Union[int, float], b: Optional[int]):
43*da0073e9SAndroid Build Coastguard Worker            return 0
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker        scripted_func = torch.jit.script(test_func)
46*da0073e9SAndroid Build Coastguard Worker        graph_rep = str(scripted_func.graph)
47*da0073e9SAndroid Build Coastguard Worker        code_rep = str(scripted_func.code)
48*da0073e9SAndroid Build Coastguard Worker        # TS graph IR for Union should be annotated as Union()
49*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Union(").check("int?").run(graph_rep)
50*da0073e9SAndroid Build Coastguard Worker        # Serialized code for Union should be annotated as Union[]
51*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
52*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_func, (5, 6))
53*da0073e9SAndroid Build Coastguard Worker        # this shouldn't error out
54*da0073e9SAndroid Build Coastguard Worker        torch._C.parse_ir(str(scripted_func.graph))
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker    def test_union_with_scalar_values(self):
57*da0073e9SAndroid Build Coastguard Worker        def fn(x: Union[int, float]) -> str:
58*da0073e9SAndroid Build Coastguard Worker            return "foo"
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1,))
61*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, (1.0,))
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(fn)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
66*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
67*da0073e9SAndroid Build Coastguard Worker            "Expected a member of"
68*da0073e9SAndroid Build Coastguard Worker            r" Union\[float, int\] but "
69*da0073e9SAndroid Build Coastguard Worker            "instead found type str",
70*da0073e9SAndroid Build Coastguard Worker        ):
71*da0073e9SAndroid Build Coastguard Worker            scripted("1")
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker    def test_union_with_collections(self):
74*da0073e9SAndroid Build Coastguard Worker        def fn(x: Union[Dict[str, int], List[int]]) -> str:
75*da0073e9SAndroid Build Coastguard Worker            return "foo"
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
78*da0073e9SAndroid Build Coastguard Worker        self.checkScript(fn, ([1, 2, 3],))
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(fn)
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
83*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
84*da0073e9SAndroid Build Coastguard Worker            "Expected a member of"
85*da0073e9SAndroid Build Coastguard Worker            r" Union\[List\[int\], Dict\[str, "
86*da0073e9SAndroid Build Coastguard Worker            r"int\]\] but instead found type "
87*da0073e9SAndroid Build Coastguard Worker            r"Dict\[str, str\]",
88*da0073e9SAndroid Build Coastguard Worker        ):
89*da0073e9SAndroid Build Coastguard Worker            scripted({"foo": "bar", "baz": "qux"})
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
92*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
93*da0073e9SAndroid Build Coastguard Worker            "Expected a member of"
94*da0073e9SAndroid Build Coastguard Worker            r" Union\[List\[int\], Dict\[str, "
95*da0073e9SAndroid Build Coastguard Worker            r"int\]\] but instead found type "
96*da0073e9SAndroid Build Coastguard Worker            r"List\[str\]",
97*da0073e9SAndroid Build Coastguard Worker        ):
98*da0073e9SAndroid Build Coastguard Worker            scripted(["foo", "bar", "baz"])
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
101*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
102*da0073e9SAndroid Build Coastguard Worker            "Expected a member of"
103*da0073e9SAndroid Build Coastguard Worker            r" Union\[List\[int\], Dict\[str, "
104*da0073e9SAndroid Build Coastguard Worker            r"int\]\] but instead found type "
105*da0073e9SAndroid Build Coastguard Worker            "str",
106*da0073e9SAndroid Build Coastguard Worker        ):
107*da0073e9SAndroid Build Coastguard Worker            scripted("1")
108*da0073e9SAndroid Build Coastguard Worker
    def test_union_with_enum(self):
        """
        Exercise a `Union[str, Color]` parameter, where `Color` is an
        `Enum` defined inside this test, with an enum member and a
        string; calling the scripted function with an `int` must raise a
        `RuntimeError`.
        """

        class Color(Enum):
            RED = 1
            GREEN = 2

        # `Color` is local to this test, so it is handed to `make_global`
        # before being used in the annotation below.
        # NOTE(review): presumably this lets the TorchScript compiler
        # resolve the name -- confirm against `make_global`.
        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        self.checkScript(fn, (Color.RED,))
        self.checkScript(fn, ("red",))

        scripted = torch.jit.script(fn)

        # The expected message hard-codes the qualified name
        # `__torch__.jit.test_union.Color`.
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a member of"
            r" Union\[__torch__.jit.test_union."
            r"Color, str\] but instead found "
            "type int",
        ):
            scripted(1)
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    def test_union_in_class_constructor(self):
134*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script  # noqa: B903
135*da0073e9SAndroid Build Coastguard Worker        class A:  # noqa: B903
136*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x: Union[int, str]) -> None:
137*da0073e9SAndroid Build Coastguard Worker                self.x = x
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker        def fn(x: Union[str, int]) -> A:
140*da0073e9SAndroid Build Coastguard Worker            return A(x)
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn("foo").x, "foo")
143*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(1).x, 1)
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker        scripted = torch.jit.script(fn)
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
148*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
149*da0073e9SAndroid Build Coastguard Worker            "Expected a member of"
150*da0073e9SAndroid Build Coastguard Worker            r" Union\[int, str\] but instead "
151*da0073e9SAndroid Build Coastguard Worker            r"found type List\[str\]",
152*da0073e9SAndroid Build Coastguard Worker        ):
153*da0073e9SAndroid Build Coastguard Worker            scripted(["foo", "bar", "baz"])
154*da0073e9SAndroid Build Coastguard Worker
    def test_union_return_type(self):
        """
        Run a function declared to return `Union[int, str]` (it returns
        a `str`) through `checkScript`.
        """

        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))
160*da0073e9SAndroid Build Coastguard Worker
    def test_union_as_annotation(self):
        """
        Run a function that annotates a local variable as
        `Union[int, str]` and returns it through `checkScript`.
        """

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
167*da0073e9SAndroid Build Coastguard Worker
    def test_union_as_annotation_in_typed_container(self):
        """
        Run a function that appends a `str` and an `int`, both annotated
        `Union[int, str]`, to a `List[Union[int, str]]` through
        `checkScript`.
        """

        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())
177*da0073e9SAndroid Build Coastguard Worker
    def test_union_as_annotation_py2(self):
        """
        Same as `test_union_as_annotation`, except the return type is
        given by a `# type:` comment instead of a `->` annotation. The
        type comment inside `fn` is part of what is being tested and must
        stay as the first line of the function body.
        """

        def fn():
            # type: () -> Union[int, str]
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
185*da0073e9SAndroid Build Coastguard Worker
    def test_union_as_internal_tuple_type(self):
        """
        Run a function that returns a `Tuple` whose two element types
        are both `Union[int, str]` (holding an `int` and a `str`)
        through `checkScript`.
        """

        def fn():
            t: Tuple[Union[int, str], Union[int, str]] = (1, "foo")
            return t

        self.checkScript(fn, ())
192*da0073e9SAndroid Build Coastguard Worker
    def test_union_variable_can_be_reassigned(self):
        """
        Run a function that rebinds a `Union[int, str]` variable first
        to an `int` and then to a `str` through `checkScript`. Between
        the two rebinds the variable is passed to `aux1`, whose parameter
        is annotated `int`.
        """

        # Scripted helper taking a plain `int`.
        @torch.jit.script
        def aux1(i: int):
            return int(i**2)

        # Scripted helper taking a plain `str`.
        @torch.jit.script
        def aux2(s: str):
            return s + s

        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())
212*da0073e9SAndroid Build Coastguard Worker
    def test_union_does_not_replace_existing_annotated_type(self):
        """
        Appending a `str` to a list annotated `List[int]` must raise a
        `RuntimeError` matching "Could not match type str".
        """

        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        # Both scripting and calling sit inside the `with`, so the error
        # may come from either step.
        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()
222*da0073e9SAndroid Build Coastguard Worker
    def test_union_does_not_replace_existing_annotated_type_union(self):
        """
        Appending a `float` to a list annotated `List[Union[int, str]]`
        must raise a `RuntimeError` matching "Could not match type
        float".
        """

        def fn():
            x: List[Union[int, str]] = [1, "foo", 3]
            x.append(2.0)
            return x

        # Both scripting and calling sit inside the `with`, so the error
        # may come from either step.
        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()
232*da0073e9SAndroid Build Coastguard Worker
    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        """
        Appending a `str` to an initially empty list annotated
        `List[int]` must raise a `RuntimeError` matching "Could not
        match type str".
        """

        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        # Both scripting and calling sit inside the `with`, so the error
        # may come from either step.
        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()
242*da0073e9SAndroid Build Coastguard Worker
    def test_unions_of_unions_are_flattened(self):
        """
        A parameter annotated `Union[Union[int, str], float]` must
        appear in the graph as `x : Union(float, int, str)`.
        """

        @torch.jit.script
        def fn(x: Union[Union[int, str], float]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : Union(float, int, str)").run(s)
251*da0073e9SAndroid Build Coastguard Worker
    def test_unions_of_a_single_argument_vanish(self):
        """
        A parameter annotated `Union[int]` must appear in the graph as
        plain `x : int`.
        """

        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : int").run(s)
260*da0073e9SAndroid Build Coastguard Worker
    def test_union_redundant_arguments_are_skipped(self):
        """
        A parameter annotated `Union[int, str, int]` must appear in the
        graph as `x : Union(int, str)`.
        """

        @torch.jit.script
        def fn(x: Union[int, str, int]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : Union(int, str)").run(s)
269*da0073e9SAndroid Build Coastguard Worker
    def test_union_redundant_arguments_are_skipped_optional(self):
        """
        A parameter annotated
        `Union[int, Optional[float], Optional[int]]` must appear in the
        graph as `x : Union(float, int, NoneType)`.
        """

        @torch.jit.script
        def fn(x: Union[int, Optional[float], Optional[int]]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : Union(float, int, NoneType)").run(s)
278*da0073e9SAndroid Build Coastguard Worker
    def test_union_redundant_arguments_are_skipped_subtyping(self):
        """
        A parameter annotated
        `Union[str, Tuple[Optional[int], int], Tuple[int, int]]` must
        appear in the graph as `x : Union((int?, int), str)`, i.e.
        without a separate `(int, int)` member.
        """

        @torch.jit.script
        def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : Union((int?, int), str)").run(s)
287*da0073e9SAndroid Build Coastguard Worker
    def test_union_redundant_arguments_are_skipped_container(self):
        """
        A parameter annotated `Union[List[str], List[float], List[str]]`
        must appear in the graph as `x : Union(float[], str[])`.
        """

        @torch.jit.script
        def fn(x: Union[List[str], List[float], List[str]]) -> str:
            return "foo"

        s = fn.graph

        # The expected text depends on the parameter being named `x`.
        FileCheck().check("x : Union(float[], str[])").run(s)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    def test_union_argument_order_is_ignored(self):
298*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
299*da0073e9SAndroid Build Coastguard Worker        def fn1(x: Union[int, str]) -> str:
300*da0073e9SAndroid Build Coastguard Worker            return "foo"
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
303*da0073e9SAndroid Build Coastguard Worker        def fn2(x: Union[str, int]) -> str:
304*da0073e9SAndroid Build Coastguard Worker            return "foo"
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker        for s in (fn1.graph, fn2.graph):
307*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("x : Union(int, str)").run(s)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker    def test_union_argument_order_is_ignored_container(self):
310*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
311*da0073e9SAndroid Build Coastguard Worker        def fn1(x: Union[List[str], List[int]]) -> str:
312*da0073e9SAndroid Build Coastguard Worker            return "foo"
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
315*da0073e9SAndroid Build Coastguard Worker        def fn2(x: Union[List[int], List[str]]) -> str:
316*da0073e9SAndroid Build Coastguard Worker            return "foo"
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker        for s in (fn1.graph, fn2.graph):
319*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("x : Union(int[], str[])").run(s)
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    def test_union_T_None_is_equivalent_to_optional_T(self):
322*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
323*da0073e9SAndroid Build Coastguard Worker        def inner(x: Union[int, None]) -> int:
324*da0073e9SAndroid Build Coastguard Worker            if x is not None:
325*da0073e9SAndroid Build Coastguard Worker                return x
326*da0073e9SAndroid Build Coastguard Worker            else:
327*da0073e9SAndroid Build Coastguard Worker                return 5
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
330*da0073e9SAndroid Build Coastguard Worker        def fn1() -> int:
331*da0073e9SAndroid Build Coastguard Worker            a: Optional[int] = 5
332*da0073e9SAndroid Build Coastguard Worker            b: Optional[int] = None
333*da0073e9SAndroid Build Coastguard Worker            a_ = inner(a)
334*da0073e9SAndroid Build Coastguard Worker            b_ = inner(b)
335*da0073e9SAndroid Build Coastguard Worker            return a_ + b_
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn1(), 10)
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
340*da0073e9SAndroid Build Coastguard Worker        def inner2(x: Optional[int]) -> int:
341*da0073e9SAndroid Build Coastguard Worker            if x is not None:
342*da0073e9SAndroid Build Coastguard Worker                return x
343*da0073e9SAndroid Build Coastguard Worker            else:
344*da0073e9SAndroid Build Coastguard Worker                return 5
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
347*da0073e9SAndroid Build Coastguard Worker        def fn2() -> int:
348*da0073e9SAndroid Build Coastguard Worker            a: Union[int, None] = 5
349*da0073e9SAndroid Build Coastguard Worker            b: Union[int, None] = None
350*da0073e9SAndroid Build Coastguard Worker            a_ = inner(a)
351*da0073e9SAndroid Build Coastguard Worker            b_ = inner(b)
352*da0073e9SAndroid Build Coastguard Worker            return a_ + b_
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn2(), 10)
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker    def test_union_optional_of_union_is_flattened(self):
357*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
358*da0073e9SAndroid Build Coastguard Worker        def fn(flag: int) -> Union[str, int, None]:
359*da0073e9SAndroid Build Coastguard Worker            y: Union[int, str, None] = "foo"
360*da0073e9SAndroid Build Coastguard Worker            if flag == 0:
361*da0073e9SAndroid Build Coastguard Worker                x: Optional[Union[int, str]] = y
362*da0073e9SAndroid Build Coastguard Worker            elif flag == 1:
363*da0073e9SAndroid Build Coastguard Worker                x: Optional[Union[int, str]] = 1
364*da0073e9SAndroid Build Coastguard Worker            else:
365*da0073e9SAndroid Build Coastguard Worker                x: Optional[Union[int, str]] = None
366*da0073e9SAndroid Build Coastguard Worker            return x
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker        # Can't use `checkScript` because it will flag the fact that
369*da0073e9SAndroid Build Coastguard Worker        # the original code has `Optional[Union[int, str]]` but the
370*da0073e9SAndroid Build Coastguard Worker        # saved/loaded code has `Union[int, NoneType, str]` (even
371*da0073e9SAndroid Build Coastguard Worker        # though this is exactly what we want)
372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(0), "foo")
373*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(1), 1)
374*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(2), None)
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO()
377*da0073e9SAndroid Build Coastguard Worker        torch.jit.save(fn, buffer)
378*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO(buffer.getvalue())
379*da0073e9SAndroid Build Coastguard Worker        l = torch.jit.load(buffer)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker        s = l.code
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("Union[int, NoneType, str]").check(
384*da0073e9SAndroid Build Coastguard Worker            "Union[int, NoneType, str]"
385*da0073e9SAndroid Build Coastguard Worker        ).run(s)
386*da0073e9SAndroid Build Coastguard Worker
    def test_union_subclasses_larger_union(self):
        """
        Run a function that returns a `Union[int, str]` variable from a
        function declared to return `Union[int, str, torch.Tensor]`
        through `checkScript`.
        """

        def fn() -> Union[int, str, torch.Tensor]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())
393*da0073e9SAndroid Build Coastguard Worker
    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        """
        Scripting a function that declares a
        `Dict[Union[int, str], str]` must raise a `RuntimeError` listing
        the supported key types.
        """

        def fn():
            x: Dict[Union[int, str], str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        # Only `torch.jit.script` is called here; `fn` is never executed.
        with self.assertRaisesRegex(
            RuntimeError,
            "only int, float, "
            "complex, Tensor, device and string keys "
            "are supported",
        ):
            torch.jit.script(fn)
410*da0073e9SAndroid Build Coastguard Worker
    def test_union_as_dict_value(self):
        """
        Run a function that stores a `str` and an `int` in a
        `Dict[str, Union[int, str]]` and returns one of the values
        through `checkScript`.
        """

        def fn():
            x: Dict[str, Union[int, str]] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())
419*da0073e9SAndroid Build Coastguard Worker
    def test_union_module_with_union_instance_variable(self):
        """
        Run a module whose attribute `x` is annotated `Union[int, str]`,
        and whose `forward` overwrites and returns it, through
        `checkModule`: once with `int` values and once with `str`
        values.
        """

        class M(torch.nn.Module):
            # Class-level annotation only; the value is set in `__init__`.
            x: Union[int, str]

            def __init__(self, x: Union[int, str]):
                super().__init__()
                self.x: Union[int, str] = x

            def forward(self, y: Union[int, str]):
                self.x = y
                return self.x

        self.checkModule(
            M(
                2,
            ),
            (1,),
        )
        self.checkModule(M("bar"), ("foo",))
439*da0073e9SAndroid Build Coastguard Worker
    def test_union_module_with_union_class_variable(self):
        # Module with a class-level `x: Union[int, str] = "foo"`.
        # NOTE(review): both `__init__` and `forward` assign to a bare local
        # `x` (not `self.x`), so the class variable is never written or read
        # by these methods; `forward` simply returns its own argument.
        # Confirm whether `self.x` was intended.
        class M(torch.nn.Module):
            x: Union[int, str] = "foo"

            def __init__(self, y: int):
                super().__init__()
                x = y

            def forward(self, z: str):
                x = z
                return x

        self.checkModule(M(1), ("foo",))
453*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement(self):
        # `isinstance(x, str)` on a Union[int, str] argument. In the true
        # branch `x` is used in a string concatenation (the result `z` is
        # deliberately unused) and returned under a `-> str` annotation, so
        # the branch only type-checks if `x` is treated as `str` there.
        def fn(x: Union[int, str]) -> str:
            if isinstance(x, str):
                z = x + "bar"
                return x
            else:
                return "baz"

        # One input per branch: str takes the `if`, int takes the `else`.
        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))
464*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_union_rhs(self):
        # Here the Union appears on the right-hand side of
        # `torch.jit.isinstance`, while the value being tested is a plain
        # `int` parameter.
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, Union[int, str]):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))
473*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_tuple_rhs(self):
        # `isinstance` with a tuple of types on the right-hand side. The
        # outer check separates (int, float) from List[str]; the inner
        # check separates int from float. The `else` branch calls `len(x)`
        # and indexes `x`, i.e. it uses `x` as a list.
        def fn(x: Union[int, float, List[str]]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        # One input for each member of the Union.
        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))
490*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        # The tuple passed to `isinstance` contains `float`, which is not a
        # member of the annotation Union[int, List[str]]. The true branch
        # still does arithmetic on `x` and the false branch still uses `x`
        # as a list.
        def fn(x: Union[int, List[str]]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))
504*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_tuple_rhs_union(self):
        # The right-hand side of `torch.jit.isinstance` is a tuple that
        # itself contains a Union. `fn` is scripted directly via the
        # decorator instead of going through `checkScript` (see TODO below).
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[int, str], float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        # Expected "2": the true branch computes 1 + 1 and stringifies it.
        self.assertEqual(fn(1), "2")
519*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_statically_false(self):
        # `x` is annotated `int`, and none of the types listed on the
        # right-hand side of `torch.jit.isinstance` is `int`. The true
        # branch contains `x + "foo"`, an int/str addition.
        # NOTE(review): presumably the check is resolved at compile time so
        # that branch is never compiled -- confirm against the frontend.
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (Union[str, float], List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        # (the graph text must contain neither "block0()" nor "block1()")
        FileCheck().check_not("block0()").check_not("block1()").run(s)
533*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_statically_true(self):
        # The types listed on the right-hand side of `torch.jit.isinstance`
        # are exactly the members of `x`'s Union annotation, and the result
        # is negated with `not`, so the `else` branch is the one matching
        # every declared type of `x`.
        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        # (the graph text must contain neither "block0()" nor "block1()")
        FileCheck().check_not("block0()").check_not("block1()").run(s)
548*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        # Only one of the types in the tuple (`int`) is a member of `x`'s
        # Union annotation; `float` and `str` are not. The true branch
        # performs `x + 1` and returns it under a `-> int` annotation.
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        # List input takes the `else` branch, int input takes the `if`.
        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))
560*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        # Same shape as the tuple-rhs variant of this test, but the
        # candidate types are given as a single Union[int, float, str]
        # instead of a tuple. Only `int` overlaps with `x`'s annotation.
        def fn(x: Union[List[int], int]) -> int:
            if torch.jit.isinstance(x, Union[int, float, str]):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        # List input takes the `else` branch, int input takes the `if`.
        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))
572*da0073e9SAndroid Build Coastguard Worker
    def test_union_type_refinement_internal_declaration(self):
        # The Union-typed variable `x` is declared inside the function body
        # (initialised to None) rather than being a parameter, and is then
        # tested with `isinstance(x, str)`.
        # NOTE(review): the `flag` branches assign to `y`, not `x`, so `x`
        # is always None and both inputs return "bar". `y` also receives a
        # str in one branch and an int in the other but is never read
        # afterwards. Confirm whether `x` was the intended target.
        def fn(flag: bool) -> str:
            x: Union[int, str, None] = None
            if flag:
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))
587*da0073e9SAndroid Build Coastguard Worker
    def test_union_branching_with_union_return_and_homogenous_types(self):
        # The return annotation is Union[int, str], but both branches
        # return a `str` literal (the "homogenous types" in the test name).
        def fn(x: int) -> Union[int, str]:
            if x % 2:
                return "foo"
            else:
                return "bar"

        # Odd input -> "foo" branch, even input -> "bar" branch.
        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))
597*da0073e9SAndroid Build Coastguard Worker
    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        # `y` has no annotation and is first bound inside the branches: to
        # a str in the true branch and to an int in the false branch. It is
        # then read after the `if`. Scripting is expected to fail rather
        # than treat `y` as Union[str, int].
        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        # The three adjacent literals concatenate into a single regex that
        # the RuntimeError message must match.
        with self.assertRaisesRegex(
            RuntimeError,
            "y is set to type str"
            " in the true branch and type int "
            "in the false branch",
        ):
            torch.jit.script(fn)
616*da0073e9SAndroid Build Coastguard Worker
    def test_union_branching_does_not_widen_existing_inferred_type(self):
        # Unlike the "autoinfer" test, `y` is bound to a str *before* the
        # branch with no annotation, and the false branch then assigns an
        # int to it. Scripting is expected to fail rather than widen `y`
        # to a Union.
        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        # The three adjacent literals concatenate into a single regex that
        # the RuntimeError message must match.
        with self.assertRaisesRegex(
            RuntimeError,
            "previously had type "
            "str but is now being assigned to a"
            " value of type int",
        ):
            torch.jit.script(fn)
636*da0073e9SAndroid Build Coastguard Worker
    def test_union_schema_matching_on_internal_type(self):
        # Both Union members are containers. After the
        # `torch.jit.isinstance` check, the true branch indexes `x` as a
        # list and the false branch calls `.values()` on it as a dict.
        def fn(x: Union[List[int], Dict[str, int]]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        # One input per Union member.
        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
646*da0073e9SAndroid Build Coastguard Worker
    def test_union_subtractive_refinement(self):
        # Negated check: under `not isinstance(x, int)` the body uses `x`
        # as a list (`append`, indexing), i.e. the type left over after
        # removing `int` from the Union. The `else` branch returns `x`
        # under the `-> int` annotation.
        def fn(x: Union[List[int], int]) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))
657*da0073e9SAndroid Build Coastguard Worker
    def test_union_subtractive_refinement_with_container(self):
        # Mirror of `test_union_subtractive_refinement`: the negated check
        # names the container type (List[int]) via `torch.jit.isinstance`,
        # so the true branch returns `x` as an int and the `else` branch
        # uses it as a list.
        def fn(x: Union[List[int], int]) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))
668*da0073e9SAndroid Build Coastguard Worker
    def test_union_memory_aliasing(self):
        # `x` is stored inside `z`, a list of Optional[List[Tensor]], and
        # fetched back as `x_alias`. After the `torch.jit.isinstance`
        # check, a tensor is appended through `x_alias`, but the function
        # returns the original name `x`.
        # NOTE(review): presumably this verifies the append through the
        # alias is visible via `x` -- confirm against `checkScript`.
        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())
680*da0073e9SAndroid Build Coastguard Worker
    def test_union_serialization_preserves_type_annotations(self):
        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: Union[str, int] = "bar"
            else:
                y: Union[str, int] = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        # Odd input binds `y` to the str literal; even input binds it to
        # the int argument.
        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker    def _assert_passes(self, template: str, ann: str, lhs: str):
701*da0073e9SAndroid Build Coastguard Worker        code = template.format(ann=ann, lhs=lhs)
702*da0073e9SAndroid Build Coastguard Worker        self.checkScript(code, (), name="fn")
703*da0073e9SAndroid Build Coastguard Worker
    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str) -> None:
        """Render `template` and assert that compiling it raises RuntimeError.

        `template` is a source string containing `{ann}` and `{lhs}`
        placeholders; `msg` is the regex the RuntimeError message must
        match. Both the construction of the CompilationUnit and the lookup
        of `fn` on it are inside the `assertRaisesRegex` block, so an error
        raised by either one satisfies the assertion.
        """
        code = template.format(ann=ann, lhs=lhs)
        with self.assertRaisesRegex(RuntimeError, msg):
            # NOTE(review): `_frames_up=1` presumably makes name resolution
            # use the caller's frame -- confirm before restructuring.
            cu = torch.jit.CompilationUnit(code, _frames_up=1)
            # The result is unused; only the attribute lookup matters.
            string_frontend = getattr(cu, "fn")  # noqa: B009
709*da0073e9SAndroid Build Coastguard Worker
    def test_union_with_list_assignment(self):
        # Table-driven test: each case renders `template` with a Union
        # annotation (`ann`) and a list-valued right-hand side (`lhs`) and
        # asserts that the rendered source either passes `_assert_passes`
        # or raises a RuntimeError matching the given regex.
        template = dedent(
            """
            def fn():
                x: {ann} = {lhs}
                if torch.jit.isinstance(x, List[torch.Tensor]):
                    x.append(torch.tensor(3))
                return x
        """
        )

        # Source snippets substituted for `{lhs}`, keyed by description.
        # "mixed" entries combine a tensor and an int in the same list.
        lhs = {
            "list_literal_empty": "[]",
            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
            "list_literal_of_str": '["foo", "bar", "baz"]',
            "list_literal_of_mixed": "[torch.arange(5), 1]",
            "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
            "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
            "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
        }

        # The bare string literals below are no-op expression statements
        # used as section headers naming the annotation under test.
        """
        Union[List[str], List[torch.Tensor]]
        """
        # Two List candidates in the Union: an empty literal gives no
        # element type to choose between them.
        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_empty"],
            "there are multiple possible List type "
            "candidates in the Union annotation",
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_of_tensor"],
        )

        self._assert_passes(
            template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
        )

        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_literal_of_mixed"],
            "none of those types match the types of the" " given list elements",
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_tensor"],
        )

        self._assert_passes(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_str"],
        )

        # TODO: Support mixed list comprehensions
        self._assert_raises(
            template,
            "Union[List[str], List[torch.Tensor]]",
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )

        """
        Union[int, torch.Tensor]
        """
        # No List member in the Union at all: every list-valued right-hand
        # side is expected to be rejected with the same message.
        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_literal_empty"],
            "Expected an Union type annotation with an " "inner List type",
        )

        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_literal_of_tensor"],
            "Expected an Union type annotation with an " "inner List type",
        )

        self._assert_raises(
            template,
            "Union[int, torch.Tensor]",
            lhs["list_comprehension_of_tensor"],
            "Expected an Union type annotation with an " "inner List type",
        )

        """
        Union[List[torch.Tensor], int]
        """
        # Exactly one List member in the Union: the empty literal is
        # accepted here, and non-tensor elements are rejected against
        # `List[Tensor]`.
        self._assert_passes(
            template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
        )

        self._assert_passes(
            template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_literal_of_str"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_literal_of_mixed"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        self._assert_passes(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_tensor"],
        )

        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_str"],
            r"List type annotation `List\[Tensor\]` did "
            "not match the types of the given list "
            "elements",
        )

        # TODO(@ansley): Support mixed list comprehensions
        self._assert_raises(
            template,
            "Union[List[torch.Tensor], int]",
            lhs["list_comprehension_of_mixed"],
            "Arguments for call are not valid",
        )
854*da0073e9SAndroid Build Coastguard Worker
855*da0073e9SAndroid Build Coastguard Worker    def test_union_with_dict_assignment(self):
856*da0073e9SAndroid Build Coastguard Worker        template = dedent(
857*da0073e9SAndroid Build Coastguard Worker            """
858*da0073e9SAndroid Build Coastguard Worker            def fn():
859*da0073e9SAndroid Build Coastguard Worker                x: {ann} = {lhs}
860*da0073e9SAndroid Build Coastguard Worker                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
861*da0073e9SAndroid Build Coastguard Worker                    x["foo"] = torch.tensor(3)
862*da0073e9SAndroid Build Coastguard Worker                return x
863*da0073e9SAndroid Build Coastguard Worker        """
864*da0073e9SAndroid Build Coastguard Worker        )
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker        lhs = {
867*da0073e9SAndroid Build Coastguard Worker            "dict_literal_empty": "{}",
868*da0073e9SAndroid Build Coastguard Worker            "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
869*da0073e9SAndroid Build Coastguard Worker            "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
870*da0073e9SAndroid Build Coastguard Worker            "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
871*da0073e9SAndroid Build Coastguard Worker            "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
872*da0073e9SAndroid Build Coastguard Worker                    zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
873*da0073e9SAndroid Build Coastguard Worker            "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
874*da0073e9SAndroid Build Coastguard Worker                    zip(["foo", "bar"], [1, 2]}',
875*da0073e9SAndroid Build Coastguard Worker            "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
876*da0073e9SAndroid Build Coastguard Worker                    zip(["foo", "bar"], [torch.arange(3), 2])}',
877*da0073e9SAndroid Build Coastguard Worker            "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
878*da0073e9SAndroid Build Coastguard Worker            "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
879*da0073e9SAndroid Build Coastguard Worker            "dict_keyword_with_empty_iterable": "dict([])",
880*da0073e9SAndroid Build Coastguard Worker            "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
881*da0073e9SAndroid Build Coastguard Worker            "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
882*da0073e9SAndroid Build Coastguard Worker            "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
883*da0073e9SAndroid Build Coastguard Worker        }
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker        """
886*da0073e9SAndroid Build Coastguard Worker        Union[Dict[str, torch.Tensor], Dict[str, int]]
887*da0073e9SAndroid Build Coastguard Worker        """
888*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
889*da0073e9SAndroid Build Coastguard Worker            template,
890*da0073e9SAndroid Build Coastguard Worker            "Union[List[str], List[torch.Tensor]]",
891*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_empty"],
892*da0073e9SAndroid Build Coastguard Worker            "Expected an Union type annotation with an " "inner Dict type",
893*da0073e9SAndroid Build Coastguard Worker        )
894*da0073e9SAndroid Build Coastguard Worker
895*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
896*da0073e9SAndroid Build Coastguard Worker            template,
897*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
898*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_str_tensor"],
899*da0073e9SAndroid Build Coastguard Worker        )
900*da0073e9SAndroid Build Coastguard Worker
901*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
902*da0073e9SAndroid Build Coastguard Worker            template,
903*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
904*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_str_int"],
905*da0073e9SAndroid Build Coastguard Worker        )
906*da0073e9SAndroid Build Coastguard Worker
907*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
908*da0073e9SAndroid Build Coastguard Worker            template,
909*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
910*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_mixed"],
911*da0073e9SAndroid Build Coastguard Worker            "none of those dict types can hold the "
912*da0073e9SAndroid Build Coastguard Worker            "types of the given keys and values",
913*da0073e9SAndroid Build Coastguard Worker        )
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker        # TODO: String frontend does not support tuple unpacking
916*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/64096
917*da0073e9SAndroid Build Coastguard Worker        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
918*da0073e9SAndroid Build Coastguard Worker        #              lhs["dict_comprehension_of_str_tensor"])
919*da0073e9SAndroid Build Coastguard Worker
920*da0073e9SAndroid Build Coastguard Worker        # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
921*da0073e9SAndroid Build Coastguard Worker        #              lhs["dict_comprehension_of_str_int"])
922*da0073e9SAndroid Build Coastguard Worker
923*da0073e9SAndroid Build Coastguard Worker        # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
924*da0073e9SAndroid Build Coastguard Worker        #              lhs["dict_comprehension_of_mixed"],
925*da0073e9SAndroid Build Coastguard Worker        #              "foobar")
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker        # self._assert_passes(template,
928*da0073e9SAndroid Build Coastguard Worker        #                    "Union[Dict[str, torch.Tensor], Dict[str, int]]",
929*da0073e9SAndroid Build Coastguard Worker        #                    lhs["dict_keyword_with_internal_aggregate_function"])
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Worker        # TODO(@ansley): Follow-up project needed for full type
932*da0073e9SAndroid Build Coastguard Worker        # inference with dict keyword (supported for dict comprehension
933*da0073e9SAndroid Build Coastguard Worker        # and dict literal already; should not be a blocker for anyone)
934*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
935*da0073e9SAndroid Build Coastguard Worker            template,
936*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
937*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword"],
938*da0073e9SAndroid Build Coastguard Worker            "full type inference is not yet supported",
939*da0073e9SAndroid Build Coastguard Worker        )
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
942*da0073e9SAndroid Build Coastguard Worker            template,
943*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
944*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_iterable"],
945*da0073e9SAndroid Build Coastguard Worker            "full type inference is not yet supported",
946*da0073e9SAndroid Build Coastguard Worker        )
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
949*da0073e9SAndroid Build Coastguard Worker            template,
950*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
951*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_empty_iterable"],
952*da0073e9SAndroid Build Coastguard Worker            "full type inference is not yet supported",
953*da0073e9SAndroid Build Coastguard Worker        )
954*da0073e9SAndroid Build Coastguard Worker
955*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
956*da0073e9SAndroid Build Coastguard Worker            template,
957*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
958*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_mapping"],
959*da0073e9SAndroid Build Coastguard Worker            "full type inference is not yet supported",
960*da0073e9SAndroid Build Coastguard Worker        )
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
963*da0073e9SAndroid Build Coastguard Worker            template,
964*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], Dict[str, int]]",
965*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_mapping_and_kwargs"],
966*da0073e9SAndroid Build Coastguard Worker            "full type inference is not yet supported",
967*da0073e9SAndroid Build Coastguard Worker        )
968*da0073e9SAndroid Build Coastguard Worker
969*da0073e9SAndroid Build Coastguard Worker        """
970*da0073e9SAndroid Build Coastguard Worker        Union[int, torch.Tensor]
971*da0073e9SAndroid Build Coastguard Worker        """
972*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
973*da0073e9SAndroid Build Coastguard Worker            template,
974*da0073e9SAndroid Build Coastguard Worker            "Union[int, torch.Tensor]",
975*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_empty"],
976*da0073e9SAndroid Build Coastguard Worker            "Expected an Union type annotation with " "an inner Dict type",
977*da0073e9SAndroid Build Coastguard Worker        )
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
980*da0073e9SAndroid Build Coastguard Worker            template,
981*da0073e9SAndroid Build Coastguard Worker            "Union[int, torch.Tensor]",
982*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_str_tensor"],
983*da0073e9SAndroid Build Coastguard Worker            "Expected an Union type annotation with " "an inner Dict type",
984*da0073e9SAndroid Build Coastguard Worker        )
985*da0073e9SAndroid Build Coastguard Worker
986*da0073e9SAndroid Build Coastguard Worker        # See above--string frontend does not support tuple unpacking
987*da0073e9SAndroid Build Coastguard Worker        # self._assert_raises(template, "Union[int, torch.Tensor]",
988*da0073e9SAndroid Build Coastguard Worker        #              lhs["dict_comprehension_of_tensor"],
989*da0073e9SAndroid Build Coastguard Worker        #              "foobar")
990*da0073e9SAndroid Build Coastguard Worker
991*da0073e9SAndroid Build Coastguard Worker        """
992*da0073e9SAndroid Build Coastguard Worker        Union[Dict[str, torch.Tensor], int]
993*da0073e9SAndroid Build Coastguard Worker        """
994*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
995*da0073e9SAndroid Build Coastguard Worker            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
996*da0073e9SAndroid Build Coastguard Worker        )
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
999*da0073e9SAndroid Build Coastguard Worker            template,
1000*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1001*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_str_tensor"],
1002*da0073e9SAndroid Build Coastguard Worker        )
1003*da0073e9SAndroid Build Coastguard Worker
1004*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
1005*da0073e9SAndroid Build Coastguard Worker            template,
1006*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1007*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_str_int"],
1008*da0073e9SAndroid Build Coastguard Worker            "Type annotation was inferred to be "
1009*da0073e9SAndroid Build Coastguard Worker            r"`Dict\[str, Tensor\]`, but the type of "
1010*da0073e9SAndroid Build Coastguard Worker            "values given by the dict literal is",
1011*da0073e9SAndroid Build Coastguard Worker        )
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Worker        self._assert_raises(
1014*da0073e9SAndroid Build Coastguard Worker            template,
1015*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1016*da0073e9SAndroid Build Coastguard Worker            lhs["dict_literal_of_mixed"],
1017*da0073e9SAndroid Build Coastguard Worker            "Type annotation was inferred to be "
1018*da0073e9SAndroid Build Coastguard Worker            r"`Dict\[str, Tensor\]`, but the type of "
1019*da0073e9SAndroid Build Coastguard Worker            "values given by the dict literal is",
1020*da0073e9SAndroid Build Coastguard Worker        )
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
1023*da0073e9SAndroid Build Coastguard Worker            template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
1024*da0073e9SAndroid Build Coastguard Worker        )
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
1027*da0073e9SAndroid Build Coastguard Worker            template,
1028*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1029*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_iterable"],
1030*da0073e9SAndroid Build Coastguard Worker        )
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
1033*da0073e9SAndroid Build Coastguard Worker            template,
1034*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1035*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_empty_iterable"],
1036*da0073e9SAndroid Build Coastguard Worker        )
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
1039*da0073e9SAndroid Build Coastguard Worker            template,
1040*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1041*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_mapping"],
1042*da0073e9SAndroid Build Coastguard Worker        )
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker        self._assert_passes(
1045*da0073e9SAndroid Build Coastguard Worker            template,
1046*da0073e9SAndroid Build Coastguard Worker            "Union[Dict[str, torch.Tensor], int]",
1047*da0073e9SAndroid Build Coastguard Worker            lhs["dict_keyword_with_mapping_and_kwargs"],
1048*da0073e9SAndroid Build Coastguard Worker        )
1049*da0073e9SAndroid Build Coastguard Worker
1050*da0073e9SAndroid Build Coastguard Worker        # See above--string frontend does not support tuple unpacking
1051*da0073e9SAndroid Build Coastguard Worker        # self._assert_passes(template,
1052*da0073e9SAndroid Build Coastguard Worker        #                    "Union[Dict[str, torch.Tensor], int]",
1053*da0073e9SAndroid Build Coastguard Worker        #                    lhs["dict_keyword_with_internal_aggregate_function"])
1054*da0073e9SAndroid Build Coastguard Worker        #
1055*da0073e9SAndroid Build Coastguard Worker        # self._assert_passes(template,
1056*da0073e9SAndroid Build Coastguard Worker        #                    "Union[Dict[str, torch.Tensor], int]",
1057*da0073e9SAndroid Build Coastguard Worker        #                    lhs["dict_comprehension_of_str_tensor"])
1058*da0073e9SAndroid Build Coastguard Worker
1059*da0073e9SAndroid Build Coastguard Worker        # self._assert_raises(template,
1060*da0073e9SAndroid Build Coastguard Worker        #                    "Union[Dict[str, torch.Tensor], int]",
1061*da0073e9SAndroid Build Coastguard Worker        #                    lhs["dict_comprehension_of_str_int"],
1062*da0073e9SAndroid Build Coastguard Worker        #                    "foobar")
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker        # self._assert_raises(template,
1065*da0073e9SAndroid Build Coastguard Worker        #                    "Union[Dict[str, torch.Tensor], int]",
1066*da0073e9SAndroid Build Coastguard Worker        #                    lhs["dict_comprehension_of_mixed"],
1067*da0073e9SAndroid Build Coastguard Worker        #                    "foobar")
1068