# Owner(s): ["oncall: jit"]

import io
import os
import sys
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestUnion(JitTestCase):
    """
    Tests for TorchScript's support of `typing.Union`.

    Note: narrowing a `Union` down to one of its member types is central to
    these tests, but Python and TorchScript currently disagree on how
    `isinstance` checks are written. A single function therefore cannot
    always run in both eager mode and script mode, which rules out
    `checkScript` for some cases; those cases use separate but equivalent
    functions to emulate what `checkScript` would verify.
    """

    def test_check_union_annotation(self):
        def test_func(a: Union[int, float], b: Optional[int]):
            return 0

        scripted_func = torch.jit.script(test_func)
        graph_text = str(scripted_func.graph)
        code_text = str(scripted_func.code)

        # The graph IR spells a Union as `Union(...)` and Optional[int] as `int?`.
        FileCheck().check("Union(").check("int?").run(graph_text)
        # Serialized code spells them as `Union[...]` and `Optional[int]`.
        FileCheck().check("Union[").check("Optional[int]").run(code_text)

        self.checkScript(test_func, (5, 6))

        # The printed graph must be accepted by the IR parser without error.
        torch._C.parse_ir(str(scripted_func.graph))

    def test_union_with_scalar_values(self):
        def fn(x: Union[int, float]) -> str:
            return "foo"

        # Each member type of the Union is accepted.
        for sample in (1, 1.0):
            self.checkScript(fn, (sample,))

        compiled = torch.jit.script(fn)

        # A str is not a member of Union[int, float].
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a member of Union\[float, int\] but instead found type str",
        ):
            compiled("1")

    def test_union_with_collections(self):
        def fn(x: Union[Dict[str, int], List[int]]) -> str:
            return "foo"

        # Both container members of the Union are accepted.
        for sample in ({"foo": 1, "bar": 2, "baz": 3}, [1, 2, 3]):
            self.checkScript(fn, (sample,))

        compiled = torch.jit.script(fn)

        expected_prefix = (
            r"Expected a member of Union\[List\[int\], Dict\[str, int\]\] "
            r"but instead found type "
        )
        # Inputs whose type lies outside the Union, paired with the type
        # name the error message is expected to report.
        rejected = (
            ({"foo": "bar", "baz": "qux"}, r"Dict\[str, str\]"),
            (["foo", "bar", "baz"], r"List\[str\]"),
            ("1", "str"),
        )
        for bad_input, found_type in rejected:
            with self.assertRaisesRegex(RuntimeError, expected_prefix + found_type):
                compiled(bad_input)

    def test_union_with_enum(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        # Register `Color` with the `make_global` test helper before any
        # code that mentions it is scripted (presumably so TorchScript can
        # resolve it by name).
        make_global(Color)

        def fn(x: Union[str, Color]) -> str:
            return "foo"

        for sample in (Color.RED, "red"):
            self.checkScript(fn, (sample,))

        compiled = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a member of Union\[__torch__.jit.test_union.Color, str\]"
            r" but instead found type int",
        ):
            compiled(1)

    def test_union_in_class_constructor(self):
        @torch.jit.script  # noqa: B903
        class A:  # noqa: B903
            def __init__(self, x: Union[int, str]) -> None:
                self.x = x

        def fn(x: Union[str, int]) -> A:
            return A(x)

        # Calling the un-scripted `fn` constructs the scripted class `A`
        # with either member type and stores the value unchanged.
        for value in ("foo", 1):
            self.assertEqual(fn(value).x, value)

        compiled = torch.jit.script(fn)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected a member of Union\[int, str\] but instead found type List\[str\]",
        ):
            compiled(["foo", "bar", "baz"])

    def test_union_return_type(self):
        # A Union return annotation is satisfied by returning one member type.
        def fn(x: int) -> Union[int, str]:
            return "foo"

        self.checkScript(fn, (1,))

    def test_union_as_annotation(self):
        # A Union-annotated local can be returned as the Union return type.
        def fn() -> Union[int, str]:
            x: Union[int, str] = "foo"
            return x

        self.checkScript(fn, ())

    def test_union_as_annotation_in_typed_container(self):
        # Union-typed locals of either member type can be appended to a
        # List[Union[...]].
        def fn() -> None:
            l: List[Union[int, str]] = []
            u1: Union[int, str] = "foo"
            u2: Union[int, str] = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())
Worker def test_union_as_annotation_py2(self): 179*da0073e9SAndroid Build Coastguard Worker def fn(): 180*da0073e9SAndroid Build Coastguard Worker # type: () -> Union[int, str] 181*da0073e9SAndroid Build Coastguard Worker x: Union[int, str] = "foo" 182*da0073e9SAndroid Build Coastguard Worker return x 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def test_union_as_internal_tuple_type(self): 187*da0073e9SAndroid Build Coastguard Worker def fn(): 188*da0073e9SAndroid Build Coastguard Worker t: Tuple[Union[int, str], Union[int, str]] = (1, "foo") 189*da0073e9SAndroid Build Coastguard Worker return t 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def test_union_variable_can_be_reassigned(self): 194*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 195*da0073e9SAndroid Build Coastguard Worker def aux1(i: int): 196*da0073e9SAndroid Build Coastguard Worker return int(i**2) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 199*da0073e9SAndroid Build Coastguard Worker def aux2(s: str): 200*da0073e9SAndroid Build Coastguard Worker return s + s 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker def fn() -> Union[int, str]: 203*da0073e9SAndroid Build Coastguard Worker x: Union[int, str] = "foo" 204*da0073e9SAndroid Build Coastguard Worker i: int = 1 205*da0073e9SAndroid Build Coastguard Worker x = i 206*da0073e9SAndroid Build Coastguard Worker y: int = aux1(x) 207*da0073e9SAndroid Build Coastguard Worker z: str = aux2(str(y)) 208*da0073e9SAndroid Build Coastguard Worker x = z 209*da0073e9SAndroid Build Coastguard Worker return x 210*da0073e9SAndroid Build 
Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker def test_union_does_not_replace_existing_annotated_type(self): 214*da0073e9SAndroid Build Coastguard Worker def fn(): 215*da0073e9SAndroid Build Coastguard Worker x: List[int] = [1, 2, 3] 216*da0073e9SAndroid Build Coastguard Worker x.append("foo") 217*da0073e9SAndroid Build Coastguard Worker return x 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Could not match type str"): 220*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(fn) 221*da0073e9SAndroid Build Coastguard Worker scripted() 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker def test_union_does_not_replace_existing_annotated_type_union(self): 224*da0073e9SAndroid Build Coastguard Worker def fn(): 225*da0073e9SAndroid Build Coastguard Worker x: List[Union[int, str]] = [1, "foo", 3] 226*da0073e9SAndroid Build Coastguard Worker x.append(2.0) 227*da0073e9SAndroid Build Coastguard Worker return x 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Could not match type float"): 230*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(fn) 231*da0073e9SAndroid Build Coastguard Worker scripted() 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker def test_union_does_not_replace_existing_annotated_type_empty_container(self): 234*da0073e9SAndroid Build Coastguard Worker def fn(): 235*da0073e9SAndroid Build Coastguard Worker x: List[int] = [] 236*da0073e9SAndroid Build Coastguard Worker x.append("foo") 237*da0073e9SAndroid Build Coastguard Worker return x 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "Could not match type str"): 240*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(fn) 241*da0073e9SAndroid Build Coastguard Worker scripted() 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def test_unions_of_unions_are_flattened(self): 244*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 245*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[Union[int, str], float]) -> str: 246*da0073e9SAndroid Build Coastguard Worker return "foo" 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker s = fn.graph 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(float, int, str)").run(s) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker def test_unions_of_a_single_argument_vanish(self): 253*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 254*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int]) -> str: 255*da0073e9SAndroid Build Coastguard Worker return "foo" 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker s = fn.graph 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : int").run(s) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker def test_union_redundant_arguments_are_skipped(self): 262*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 263*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, str, int]) -> str: 264*da0073e9SAndroid Build Coastguard Worker return "foo" 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker s = fn.graph 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(int, str)").run(s) 269*da0073e9SAndroid Build Coastguard 
Worker 270*da0073e9SAndroid Build Coastguard Worker def test_union_redundant_arguments_are_skipped_optional(self): 271*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 272*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, Optional[float], Optional[int]]) -> str: 273*da0073e9SAndroid Build Coastguard Worker return "foo" 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker s = fn.graph 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(float, int, NoneType)").run(s) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def test_union_redundant_arguments_are_skipped_subtyping(self): 280*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 281*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str: 282*da0073e9SAndroid Build Coastguard Worker return "foo" 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker s = fn.graph 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union((int?, int), str)").run(s) 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker def test_union_redundant_arguments_are_skipped_container(self): 289*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 290*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[str], List[float], List[str]]) -> str: 291*da0073e9SAndroid Build Coastguard Worker return "foo" 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker s = fn.graph 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(float[], str[])").run(s) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def 
test_union_argument_order_is_ignored(self): 298*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 299*da0073e9SAndroid Build Coastguard Worker def fn1(x: Union[int, str]) -> str: 300*da0073e9SAndroid Build Coastguard Worker return "foo" 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 303*da0073e9SAndroid Build Coastguard Worker def fn2(x: Union[str, int]) -> str: 304*da0073e9SAndroid Build Coastguard Worker return "foo" 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker for s in (fn1.graph, fn2.graph): 307*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(int, str)").run(s) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker def test_union_argument_order_is_ignored_container(self): 310*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 311*da0073e9SAndroid Build Coastguard Worker def fn1(x: Union[List[str], List[int]]) -> str: 312*da0073e9SAndroid Build Coastguard Worker return "foo" 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 315*da0073e9SAndroid Build Coastguard Worker def fn2(x: Union[List[int], List[str]]) -> str: 316*da0073e9SAndroid Build Coastguard Worker return "foo" 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker for s in (fn1.graph, fn2.graph): 319*da0073e9SAndroid Build Coastguard Worker FileCheck().check("x : Union(int[], str[])").run(s) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def test_union_T_None_is_equivalent_to_optional_T(self): 322*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 323*da0073e9SAndroid Build Coastguard Worker def inner(x: Union[int, None]) -> int: 324*da0073e9SAndroid Build Coastguard Worker if x is not None: 325*da0073e9SAndroid Build Coastguard Worker return x 326*da0073e9SAndroid Build Coastguard 
Worker else: 327*da0073e9SAndroid Build Coastguard Worker return 5 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 330*da0073e9SAndroid Build Coastguard Worker def fn1() -> int: 331*da0073e9SAndroid Build Coastguard Worker a: Optional[int] = 5 332*da0073e9SAndroid Build Coastguard Worker b: Optional[int] = None 333*da0073e9SAndroid Build Coastguard Worker a_ = inner(a) 334*da0073e9SAndroid Build Coastguard Worker b_ = inner(b) 335*da0073e9SAndroid Build Coastguard Worker return a_ + b_ 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn1(), 10) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 340*da0073e9SAndroid Build Coastguard Worker def inner2(x: Optional[int]) -> int: 341*da0073e9SAndroid Build Coastguard Worker if x is not None: 342*da0073e9SAndroid Build Coastguard Worker return x 343*da0073e9SAndroid Build Coastguard Worker else: 344*da0073e9SAndroid Build Coastguard Worker return 5 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 347*da0073e9SAndroid Build Coastguard Worker def fn2() -> int: 348*da0073e9SAndroid Build Coastguard Worker a: Union[int, None] = 5 349*da0073e9SAndroid Build Coastguard Worker b: Union[int, None] = None 350*da0073e9SAndroid Build Coastguard Worker a_ = inner(a) 351*da0073e9SAndroid Build Coastguard Worker b_ = inner(b) 352*da0073e9SAndroid Build Coastguard Worker return a_ + b_ 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn2(), 10) 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker def test_union_optional_of_union_is_flattened(self): 357*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 358*da0073e9SAndroid Build Coastguard Worker def fn(flag: int) -> Union[str, int, None]: 
359*da0073e9SAndroid Build Coastguard Worker y: Union[int, str, None] = "foo" 360*da0073e9SAndroid Build Coastguard Worker if flag == 0: 361*da0073e9SAndroid Build Coastguard Worker x: Optional[Union[int, str]] = y 362*da0073e9SAndroid Build Coastguard Worker elif flag == 1: 363*da0073e9SAndroid Build Coastguard Worker x: Optional[Union[int, str]] = 1 364*da0073e9SAndroid Build Coastguard Worker else: 365*da0073e9SAndroid Build Coastguard Worker x: Optional[Union[int, str]] = None 366*da0073e9SAndroid Build Coastguard Worker return x 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker # Can't use `checkScript` because it will flag the fact that 369*da0073e9SAndroid Build Coastguard Worker # the original code has `Optional[Union[int, str]]` but the 370*da0073e9SAndroid Build Coastguard Worker # saved/loaded code has `Union[int, NoneType, str]` (even 371*da0073e9SAndroid Build Coastguard Worker # though this is exactly what we want) 372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(0), "foo") 373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(1), 1) 374*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(2), None) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 377*da0073e9SAndroid Build Coastguard Worker torch.jit.save(fn, buffer) 378*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO(buffer.getvalue()) 379*da0073e9SAndroid Build Coastguard Worker l = torch.jit.load(buffer) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker s = l.code 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Union[int, NoneType, str]").check( 384*da0073e9SAndroid Build Coastguard Worker "Union[int, NoneType, str]" 385*da0073e9SAndroid Build Coastguard Worker ).run(s) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build 
Coastguard Worker def test_union_subclasses_larger_union(self): 388*da0073e9SAndroid Build Coastguard Worker def fn() -> Union[int, str, torch.Tensor]: 389*da0073e9SAndroid Build Coastguard Worker x: Union[int, str] = "foo" 390*da0073e9SAndroid Build Coastguard Worker return x 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker # TODO: We would like to eventually support this. The issue is being 395*da0073e9SAndroid Build Coastguard Worker # tracked at https://github.com/pytorch/pytorch/issues/58167 396*da0073e9SAndroid Build Coastguard Worker def test_union_as_dict_key(self): 397*da0073e9SAndroid Build Coastguard Worker def fn(): 398*da0073e9SAndroid Build Coastguard Worker x: Dict[Union[int, str], str] = {} 399*da0073e9SAndroid Build Coastguard Worker x["foo"] = "bar" 400*da0073e9SAndroid Build Coastguard Worker x[1] = 2 401*da0073e9SAndroid Build Coastguard Worker return x[1] 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 404*da0073e9SAndroid Build Coastguard Worker RuntimeError, 405*da0073e9SAndroid Build Coastguard Worker "only int, float, " 406*da0073e9SAndroid Build Coastguard Worker "complex, Tensor, device and string keys " 407*da0073e9SAndroid Build Coastguard Worker "are supported", 408*da0073e9SAndroid Build Coastguard Worker ): 409*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker def test_union_as_dict_value(self): 412*da0073e9SAndroid Build Coastguard Worker def fn(): 413*da0073e9SAndroid Build Coastguard Worker x: Dict[str, Union[int, str]] = {} 414*da0073e9SAndroid Build Coastguard Worker x["foo"] = "bar" 415*da0073e9SAndroid Build Coastguard Worker x["baz"] = 2 416*da0073e9SAndroid Build Coastguard Worker return x["baz"] 
417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker def test_union_module_with_union_instance_variable(self): 421*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 422*da0073e9SAndroid Build Coastguard Worker x: Union[int, str] 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker def __init__(self, x: Union[int, str]): 425*da0073e9SAndroid Build Coastguard Worker super().__init__() 426*da0073e9SAndroid Build Coastguard Worker self.x: Union[int, str] = x 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker def forward(self, y: Union[int, str]): 429*da0073e9SAndroid Build Coastguard Worker self.x = y 430*da0073e9SAndroid Build Coastguard Worker return self.x 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker self.checkModule( 433*da0073e9SAndroid Build Coastguard Worker M( 434*da0073e9SAndroid Build Coastguard Worker 2, 435*da0073e9SAndroid Build Coastguard Worker ), 436*da0073e9SAndroid Build Coastguard Worker (1,), 437*da0073e9SAndroid Build Coastguard Worker ) 438*da0073e9SAndroid Build Coastguard Worker self.checkModule(M("bar"), ("foo",)) 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker def test_union_module_with_union_class_variable(self): 441*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 442*da0073e9SAndroid Build Coastguard Worker x: Union[int, str] = "foo" 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker def __init__(self, y: int): 445*da0073e9SAndroid Build Coastguard Worker super().__init__() 446*da0073e9SAndroid Build Coastguard Worker x = y 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker def forward(self, z: str): 449*da0073e9SAndroid 
Build Coastguard Worker x = z 450*da0073e9SAndroid Build Coastguard Worker return x 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(1), ("foo",)) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement(self): 455*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, str]) -> str: 456*da0073e9SAndroid Build Coastguard Worker if isinstance(x, str): 457*da0073e9SAndroid Build Coastguard Worker z = x + "bar" 458*da0073e9SAndroid Build Coastguard Worker return x 459*da0073e9SAndroid Build Coastguard Worker else: 460*da0073e9SAndroid Build Coastguard Worker return "baz" 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("foo",)) 463*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_union_rhs(self): 466*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 467*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, Union[int, str]): 468*da0073e9SAndroid Build Coastguard Worker return "bar" 469*da0073e9SAndroid Build Coastguard Worker else: 470*da0073e9SAndroid Build Coastguard Worker return "baz" 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_tuple_rhs(self): 475*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, float, List[str]]) -> str: 476*da0073e9SAndroid Build Coastguard Worker if isinstance(x, (int, float)): 477*da0073e9SAndroid Build Coastguard Worker if isinstance(x, int): 478*da0073e9SAndroid Build Coastguard Worker return str(x) 479*da0073e9SAndroid Build Coastguard Worker else: 480*da0073e9SAndroid Build 
Coastguard Worker return "foo" 481*da0073e9SAndroid Build Coastguard Worker else: 482*da0073e9SAndroid Build Coastguard Worker if len(x): 483*da0073e9SAndroid Build Coastguard Worker return x[0] 484*da0073e9SAndroid Build Coastguard Worker else: 485*da0073e9SAndroid Build Coastguard Worker return "bar" 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 488*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1.0,)) 489*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (["a", "b", "c"],)) 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_tuple_rhs_noncontained_type(self): 492*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, List[str]]) -> str: 493*da0073e9SAndroid Build Coastguard Worker if isinstance(x, (int, float)): 494*da0073e9SAndroid Build Coastguard Worker y = x + x 495*da0073e9SAndroid Build Coastguard Worker return str(y) 496*da0073e9SAndroid Build Coastguard Worker else: 497*da0073e9SAndroid Build Coastguard Worker if len(x): 498*da0073e9SAndroid Build Coastguard Worker return x[0] 499*da0073e9SAndroid Build Coastguard Worker else: 500*da0073e9SAndroid Build Coastguard Worker return "bar" 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 503*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (["a", "b", "c"],)) 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_tuple_rhs_union(self): 506*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 507*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 508*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, (Union[int, str], float)): 509*da0073e9SAndroid Build Coastguard Worker y = x + x 510*da0073e9SAndroid Build Coastguard Worker return str(y) 511*da0073e9SAndroid 
Build Coastguard Worker else: 512*da0073e9SAndroid Build Coastguard Worker return "foo" 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker # TODO: There's currently an unrelated bug in 515*da0073e9SAndroid Build Coastguard Worker # `torch.jit.isinstance` that makes it fail for tuple literals. 516*da0073e9SAndroid Build Coastguard Worker # Posted here: https://github.com/pytorch/pytorch/issues/60095 517*da0073e9SAndroid Build Coastguard Worker # Change `assertEqual` to `checkScript` when the bug is fixed 518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(1), "2") 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_statically_false(self): 521*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 522*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 523*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, (Union[str, float], List[str], str)): 524*da0073e9SAndroid Build Coastguard Worker z = x + "foo" 525*da0073e9SAndroid Build Coastguard Worker return z 526*da0073e9SAndroid Build Coastguard Worker else: 527*da0073e9SAndroid Build Coastguard Worker return "bar" 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker s = fn.graph 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker # Check that we don't have any branching statements 532*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("block0()").check_not("block1()").run(s) 533*da0073e9SAndroid Build Coastguard Worker 534*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_statically_true(self): 535*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 536*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[int], int]) -> Union[List[int], int]: 537*da0073e9SAndroid Build Coastguard Worker if not torch.jit.isinstance(x, (int, List[int])): 
538*da0073e9SAndroid Build Coastguard Worker return x 539*da0073e9SAndroid Build Coastguard Worker else: 540*da0073e9SAndroid Build Coastguard Worker l = [1, 2, 3] 541*da0073e9SAndroid Build Coastguard Worker y: Union[List[int], int] = l 542*da0073e9SAndroid Build Coastguard Worker return y 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker s = fn.graph 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker # Check that we don't have any branching statements 547*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("block0()").check_not("block1()").run(s) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): 550*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[int], int]) -> int: 551*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, (int, float, str)): 552*da0073e9SAndroid Build Coastguard Worker # We should know that `x` is an `int` here 553*da0073e9SAndroid Build Coastguard Worker z = x + 1 554*da0073e9SAndroid Build Coastguard Worker return z 555*da0073e9SAndroid Build Coastguard Worker else: 556*da0073e9SAndroid Build Coastguard Worker return 100 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 559*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_partial_static_refinement_union_rhs(self): 562*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[int], int]) -> int: 563*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, Union[int, float, str]): 564*da0073e9SAndroid Build Coastguard Worker # We should know that `x` is an `int` here 565*da0073e9SAndroid Build Coastguard Worker z = x + 1 566*da0073e9SAndroid 
Build Coastguard Worker return z 567*da0073e9SAndroid Build Coastguard Worker else: 568*da0073e9SAndroid Build Coastguard Worker return 100 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 571*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker def test_union_type_refinement_internal_declaration(self): 574*da0073e9SAndroid Build Coastguard Worker def fn(flag: bool) -> str: 575*da0073e9SAndroid Build Coastguard Worker x: Union[int, str, None] = None 576*da0073e9SAndroid Build Coastguard Worker if flag: 577*da0073e9SAndroid Build Coastguard Worker y = "foo" 578*da0073e9SAndroid Build Coastguard Worker else: 579*da0073e9SAndroid Build Coastguard Worker y = 1 580*da0073e9SAndroid Build Coastguard Worker if isinstance(x, str): 581*da0073e9SAndroid Build Coastguard Worker return x 582*da0073e9SAndroid Build Coastguard Worker else: 583*da0073e9SAndroid Build Coastguard Worker return "bar" 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (True,)) 586*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (False,)) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker def test_union_branching_with_union_return_and_homogenous_types(self): 589*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> Union[int, str]: 590*da0073e9SAndroid Build Coastguard Worker if x % 2: 591*da0073e9SAndroid Build Coastguard Worker return "foo" 592*da0073e9SAndroid Build Coastguard Worker else: 593*da0073e9SAndroid Build Coastguard Worker return "bar" 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 596*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (8,)) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build 
Coastguard Worker def test_union_branching_does_not_autoinfer_undeclared_union(self): 599*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 600*da0073e9SAndroid Build Coastguard Worker if x % 2: 601*da0073e9SAndroid Build Coastguard Worker y = "foo" 602*da0073e9SAndroid Build Coastguard Worker else: 603*da0073e9SAndroid Build Coastguard Worker y = x 604*da0073e9SAndroid Build Coastguard Worker if isinstance(y, str): 605*da0073e9SAndroid Build Coastguard Worker return y 606*da0073e9SAndroid Build Coastguard Worker else: 607*da0073e9SAndroid Build Coastguard Worker return "bar" 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 610*da0073e9SAndroid Build Coastguard Worker RuntimeError, 611*da0073e9SAndroid Build Coastguard Worker "y is set to type str" 612*da0073e9SAndroid Build Coastguard Worker " in the true branch and type int " 613*da0073e9SAndroid Build Coastguard Worker "in the false branch", 614*da0073e9SAndroid Build Coastguard Worker ): 615*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Worker def test_union_branching_does_not_widen_existing_inferred_type(self): 618*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 619*da0073e9SAndroid Build Coastguard Worker y = "foo" 620*da0073e9SAndroid Build Coastguard Worker if x % 2: 621*da0073e9SAndroid Build Coastguard Worker y = "bar" 622*da0073e9SAndroid Build Coastguard Worker else: 623*da0073e9SAndroid Build Coastguard Worker y = x 624*da0073e9SAndroid Build Coastguard Worker if isinstance(y, str): 625*da0073e9SAndroid Build Coastguard Worker return y 626*da0073e9SAndroid Build Coastguard Worker else: 627*da0073e9SAndroid Build Coastguard Worker return "baz" 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 630*da0073e9SAndroid Build Coastguard Worker 
RuntimeError, 631*da0073e9SAndroid Build Coastguard Worker "previously had type " 632*da0073e9SAndroid Build Coastguard Worker "str but is now being assigned to a" 633*da0073e9SAndroid Build Coastguard Worker " value of type int", 634*da0073e9SAndroid Build Coastguard Worker ): 635*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker def test_union_schema_matching_on_internal_type(self): 638*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[int], Dict[str, int]]) -> int: 639*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, List[int]): 640*da0073e9SAndroid Build Coastguard Worker return x[0] 641*da0073e9SAndroid Build Coastguard Worker else: 642*da0073e9SAndroid Build Coastguard Worker return list(x.values())[0] 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 645*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) 646*da0073e9SAndroid Build Coastguard Worker 647*da0073e9SAndroid Build Coastguard Worker def test_union_subtractive_refinement(self): 648*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[List[int], int]) -> int: 649*da0073e9SAndroid Build Coastguard Worker if not isinstance(x, int): 650*da0073e9SAndroid Build Coastguard Worker x.append(1) 651*da0073e9SAndroid Build Coastguard Worker return x[0] 652*da0073e9SAndroid Build Coastguard Worker else: 653*da0073e9SAndroid Build Coastguard Worker return x 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 656*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker def test_union_subtractive_refinement_with_container(self): 659*da0073e9SAndroid Build Coastguard Worker def fn(x: 
Union[List[int], int]) -> int: 660*da0073e9SAndroid Build Coastguard Worker if not torch.jit.isinstance(x, List[int]): 661*da0073e9SAndroid Build Coastguard Worker return x 662*da0073e9SAndroid Build Coastguard Worker else: 663*da0073e9SAndroid Build Coastguard Worker x.append(1) 664*da0073e9SAndroid Build Coastguard Worker return x[0] 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 667*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3],)) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def test_union_memory_aliasing(self): 670*da0073e9SAndroid Build Coastguard Worker def fn(): 671*da0073e9SAndroid Build Coastguard Worker x: List[torch.Tensor] = [] 672*da0073e9SAndroid Build Coastguard Worker z: List[Optional[List[torch.Tensor]]] = [] 673*da0073e9SAndroid Build Coastguard Worker z.append(x) 674*da0073e9SAndroid Build Coastguard Worker x_alias = z[0] 675*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x_alias, List[torch.Tensor]): 676*da0073e9SAndroid Build Coastguard Worker x_alias.append(torch.tensor(3)) 677*da0073e9SAndroid Build Coastguard Worker return x 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker def test_union_serialization_preserves_type_annotations(self): 682*da0073e9SAndroid Build Coastguard Worker # This function will fail after being torch.jit.save'd and 683*da0073e9SAndroid Build Coastguard Worker # torch.jit.load'd if the type annotations aren't preserved 684*da0073e9SAndroid Build Coastguard Worker # for Union during serialization. 
We need the `Union[str, int]` 685*da0073e9SAndroid Build Coastguard Worker # annotation to make sure that `y` is typed as a Union instead 686*da0073e9SAndroid Build Coastguard Worker # of as a str in one branch and an int in the other 687*da0073e9SAndroid Build Coastguard Worker def fn(x: int) -> str: 688*da0073e9SAndroid Build Coastguard Worker if x % 2: 689*da0073e9SAndroid Build Coastguard Worker y: Union[str, int] = "bar" 690*da0073e9SAndroid Build Coastguard Worker else: 691*da0073e9SAndroid Build Coastguard Worker y: Union[str, int] = x 692*da0073e9SAndroid Build Coastguard Worker if isinstance(y, str): 693*da0073e9SAndroid Build Coastguard Worker return y 694*da0073e9SAndroid Build Coastguard Worker else: 695*da0073e9SAndroid Build Coastguard Worker return "baz" 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 698*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (8,)) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker def _assert_passes(self, template: str, ann: str, lhs: str): 701*da0073e9SAndroid Build Coastguard Worker code = template.format(ann=ann, lhs=lhs) 702*da0073e9SAndroid Build Coastguard Worker self.checkScript(code, (), name="fn") 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker def _assert_raises(self, template: str, ann: str, lhs: str, msg: str): 705*da0073e9SAndroid Build Coastguard Worker code = template.format(ann=ann, lhs=lhs) 706*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 707*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code, _frames_up=1) 708*da0073e9SAndroid Build Coastguard Worker string_frontend = getattr(cu, "fn") # noqa: B009 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker def test_union_with_list_assignment(self): 711*da0073e9SAndroid Build Coastguard Worker 
template = dedent( 712*da0073e9SAndroid Build Coastguard Worker """ 713*da0073e9SAndroid Build Coastguard Worker def fn(): 714*da0073e9SAndroid Build Coastguard Worker x: {ann} = {lhs} 715*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, List[torch.Tensor]): 716*da0073e9SAndroid Build Coastguard Worker x.append(torch.tensor(3)) 717*da0073e9SAndroid Build Coastguard Worker return x 718*da0073e9SAndroid Build Coastguard Worker """ 719*da0073e9SAndroid Build Coastguard Worker ) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker lhs = { 722*da0073e9SAndroid Build Coastguard Worker "list_literal_empty": "[]", 723*da0073e9SAndroid Build Coastguard Worker "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]", 724*da0073e9SAndroid Build Coastguard Worker "list_literal_of_str": '["foo", "bar", "baz"]', 725*da0073e9SAndroid Build Coastguard Worker "list_literal_of_mixed": "[torch.arange(5), 1]", 726*da0073e9SAndroid Build Coastguard Worker "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]", 727*da0073e9SAndroid Build Coastguard Worker "list_comprehension_of_str": '[x + "!" 
for x in ["foo", "bar", "baz"]]', 728*da0073e9SAndroid Build Coastguard Worker "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]", 729*da0073e9SAndroid Build Coastguard Worker } 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker """ 732*da0073e9SAndroid Build Coastguard Worker Union[List[str], List[torch.Tensor]] 733*da0073e9SAndroid Build Coastguard Worker """ 734*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 735*da0073e9SAndroid Build Coastguard Worker template, 736*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 737*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_empty"], 738*da0073e9SAndroid Build Coastguard Worker "there are multiple possible List type " 739*da0073e9SAndroid Build Coastguard Worker "candidates in the Union annotation", 740*da0073e9SAndroid Build Coastguard Worker ) 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 743*da0073e9SAndroid Build Coastguard Worker template, 744*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 745*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_of_tensor"], 746*da0073e9SAndroid Build Coastguard Worker ) 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 749*da0073e9SAndroid Build Coastguard Worker template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"] 750*da0073e9SAndroid Build Coastguard Worker ) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 753*da0073e9SAndroid Build Coastguard Worker template, 754*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 755*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_of_mixed"], 756*da0073e9SAndroid Build Coastguard Worker "none of those types match the types of the" " 
given list elements", 757*da0073e9SAndroid Build Coastguard Worker ) 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 760*da0073e9SAndroid Build Coastguard Worker template, 761*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 762*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_tensor"], 763*da0073e9SAndroid Build Coastguard Worker ) 764*da0073e9SAndroid Build Coastguard Worker 765*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 766*da0073e9SAndroid Build Coastguard Worker template, 767*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 768*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_str"], 769*da0073e9SAndroid Build Coastguard Worker ) 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker # TODO: Support mixed list comprehensions 772*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 773*da0073e9SAndroid Build Coastguard Worker template, 774*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 775*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_mixed"], 776*da0073e9SAndroid Build Coastguard Worker "Arguments for call are not valid", 777*da0073e9SAndroid Build Coastguard Worker ) 778*da0073e9SAndroid Build Coastguard Worker 779*da0073e9SAndroid Build Coastguard Worker """ 780*da0073e9SAndroid Build Coastguard Worker Union[int, torch.Tensor] 781*da0073e9SAndroid Build Coastguard Worker """ 782*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 783*da0073e9SAndroid Build Coastguard Worker template, 784*da0073e9SAndroid Build Coastguard Worker "Union[int, torch.Tensor]", 785*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_empty"], 786*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with an " "inner List type", 787*da0073e9SAndroid Build Coastguard 
Worker ) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 790*da0073e9SAndroid Build Coastguard Worker template, 791*da0073e9SAndroid Build Coastguard Worker "Union[int, torch.Tensor]", 792*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_of_tensor"], 793*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with an " "inner List type", 794*da0073e9SAndroid Build Coastguard Worker ) 795*da0073e9SAndroid Build Coastguard Worker 796*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 797*da0073e9SAndroid Build Coastguard Worker template, 798*da0073e9SAndroid Build Coastguard Worker "Union[int, torch.Tensor]", 799*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_tensor"], 800*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with an " "inner List type", 801*da0073e9SAndroid Build Coastguard Worker ) 802*da0073e9SAndroid Build Coastguard Worker 803*da0073e9SAndroid Build Coastguard Worker """ 804*da0073e9SAndroid Build Coastguard Worker Union[List[torch.Tensor], int] 805*da0073e9SAndroid Build Coastguard Worker """ 806*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 807*da0073e9SAndroid Build Coastguard Worker template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"] 808*da0073e9SAndroid Build Coastguard Worker ) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 811*da0073e9SAndroid Build Coastguard Worker template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"] 812*da0073e9SAndroid Build Coastguard Worker ) 813*da0073e9SAndroid Build Coastguard Worker 814*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 815*da0073e9SAndroid Build Coastguard Worker template, 816*da0073e9SAndroid Build Coastguard Worker "Union[List[torch.Tensor], int]", 817*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_of_str"], 
818*da0073e9SAndroid Build Coastguard Worker r"List type annotation `List\[Tensor\]` did " 819*da0073e9SAndroid Build Coastguard Worker "not match the types of the given list " 820*da0073e9SAndroid Build Coastguard Worker "elements", 821*da0073e9SAndroid Build Coastguard Worker ) 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 824*da0073e9SAndroid Build Coastguard Worker template, 825*da0073e9SAndroid Build Coastguard Worker "Union[List[torch.Tensor], int]", 826*da0073e9SAndroid Build Coastguard Worker lhs["list_literal_of_mixed"], 827*da0073e9SAndroid Build Coastguard Worker r"List type annotation `List\[Tensor\]` did " 828*da0073e9SAndroid Build Coastguard Worker "not match the types of the given list " 829*da0073e9SAndroid Build Coastguard Worker "elements", 830*da0073e9SAndroid Build Coastguard Worker ) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 833*da0073e9SAndroid Build Coastguard Worker template, 834*da0073e9SAndroid Build Coastguard Worker "Union[List[torch.Tensor], int]", 835*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_tensor"], 836*da0073e9SAndroid Build Coastguard Worker ) 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 839*da0073e9SAndroid Build Coastguard Worker template, 840*da0073e9SAndroid Build Coastguard Worker "Union[List[torch.Tensor], int]", 841*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_str"], 842*da0073e9SAndroid Build Coastguard Worker r"List type annotation `List\[Tensor\]` did " 843*da0073e9SAndroid Build Coastguard Worker "not match the types of the given list " 844*da0073e9SAndroid Build Coastguard Worker "elements", 845*da0073e9SAndroid Build Coastguard Worker ) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker # TODO(@ansley): Support mixed list 
comprehensions 848*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 849*da0073e9SAndroid Build Coastguard Worker template, 850*da0073e9SAndroid Build Coastguard Worker "Union[List[torch.Tensor], int]", 851*da0073e9SAndroid Build Coastguard Worker lhs["list_comprehension_of_mixed"], 852*da0073e9SAndroid Build Coastguard Worker "Arguments for call are not valid", 853*da0073e9SAndroid Build Coastguard Worker ) 854*da0073e9SAndroid Build Coastguard Worker 855*da0073e9SAndroid Build Coastguard Worker def test_union_with_dict_assignment(self): 856*da0073e9SAndroid Build Coastguard Worker template = dedent( 857*da0073e9SAndroid Build Coastguard Worker """ 858*da0073e9SAndroid Build Coastguard Worker def fn(): 859*da0073e9SAndroid Build Coastguard Worker x: {ann} = {lhs} 860*da0073e9SAndroid Build Coastguard Worker if torch.jit.isinstance(x, Dict[str, torch.Tensor]): 861*da0073e9SAndroid Build Coastguard Worker x["foo"] = torch.tensor(3) 862*da0073e9SAndroid Build Coastguard Worker return x 863*da0073e9SAndroid Build Coastguard Worker """ 864*da0073e9SAndroid Build Coastguard Worker ) 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker lhs = { 867*da0073e9SAndroid Build Coastguard Worker "dict_literal_empty": "{}", 868*da0073e9SAndroid Build Coastguard Worker "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}', 869*da0073e9SAndroid Build Coastguard Worker "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}', 870*da0073e9SAndroid Build Coastguard Worker "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}', 871*da0073e9SAndroid Build Coastguard Worker "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \ 872*da0073e9SAndroid Build Coastguard Worker zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}', 873*da0073e9SAndroid Build Coastguard Worker "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \ 874*da0073e9SAndroid Build Coastguard Worker 
zip(["foo", "bar"], [1, 2]}', 875*da0073e9SAndroid Build Coastguard Worker "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \ 876*da0073e9SAndroid Build Coastguard Worker zip(["foo", "bar"], [torch.arange(3), 2])}', 877*da0073e9SAndroid Build Coastguard Worker "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))", 878*da0073e9SAndroid Build Coastguard Worker "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])', 879*da0073e9SAndroid Build Coastguard Worker "dict_keyword_with_empty_iterable": "dict([])", 880*da0073e9SAndroid Build Coastguard Worker "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])', 881*da0073e9SAndroid Build Coastguard Worker "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})', 882*da0073e9SAndroid Build Coastguard Worker "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))', 883*da0073e9SAndroid Build Coastguard Worker } 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker """ 886*da0073e9SAndroid Build Coastguard Worker Union[Dict[str, torch.Tensor], Dict[str, int]] 887*da0073e9SAndroid Build Coastguard Worker """ 888*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 889*da0073e9SAndroid Build Coastguard Worker template, 890*da0073e9SAndroid Build Coastguard Worker "Union[List[str], List[torch.Tensor]]", 891*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_empty"], 892*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with an " "inner Dict type", 893*da0073e9SAndroid Build Coastguard Worker ) 894*da0073e9SAndroid Build Coastguard Worker 895*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 896*da0073e9SAndroid Build Coastguard Worker template, 897*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, 
int]]", 898*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_str_tensor"], 899*da0073e9SAndroid Build Coastguard Worker ) 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 902*da0073e9SAndroid Build Coastguard Worker template, 903*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 904*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_str_int"], 905*da0073e9SAndroid Build Coastguard Worker ) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 908*da0073e9SAndroid Build Coastguard Worker template, 909*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 910*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_mixed"], 911*da0073e9SAndroid Build Coastguard Worker "none of those dict types can hold the " 912*da0073e9SAndroid Build Coastguard Worker "types of the given keys and values", 913*da0073e9SAndroid Build Coastguard Worker ) 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker # TODO: String frontend does not support tuple unpacking 916*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/64096 917*da0073e9SAndroid Build Coastguard Worker # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 918*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_str_tensor"]) 919*da0073e9SAndroid Build Coastguard Worker 920*da0073e9SAndroid Build Coastguard Worker # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 921*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_str_int"]) 922*da0073e9SAndroid Build Coastguard Worker 923*da0073e9SAndroid Build Coastguard Worker # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", 924*da0073e9SAndroid Build Coastguard 
Worker # lhs["dict_comprehension_of_mixed"], 925*da0073e9SAndroid Build Coastguard Worker # "foobar") 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker # self._assert_passes(template, 928*da0073e9SAndroid Build Coastguard Worker # "Union[Dict[str, torch.Tensor], Dict[str, int]]", 929*da0073e9SAndroid Build Coastguard Worker # lhs["dict_keyword_with_internal_aggregate_function"]) 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker # TODO(@ansley): Follow-up project needed for full type 932*da0073e9SAndroid Build Coastguard Worker # inference with dict keyword (supported for dict comprehension 933*da0073e9SAndroid Build Coastguard Worker # and dict literal already; should not be a blocker for anyone) 934*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 935*da0073e9SAndroid Build Coastguard Worker template, 936*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 937*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword"], 938*da0073e9SAndroid Build Coastguard Worker "full type inference is not yet supported", 939*da0073e9SAndroid Build Coastguard Worker ) 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 942*da0073e9SAndroid Build Coastguard Worker template, 943*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 944*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_iterable"], 945*da0073e9SAndroid Build Coastguard Worker "full type inference is not yet supported", 946*da0073e9SAndroid Build Coastguard Worker ) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 949*da0073e9SAndroid Build Coastguard Worker template, 950*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 951*da0073e9SAndroid Build Coastguard Worker 
lhs["dict_keyword_with_empty_iterable"], 952*da0073e9SAndroid Build Coastguard Worker "full type inference is not yet supported", 953*da0073e9SAndroid Build Coastguard Worker ) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 956*da0073e9SAndroid Build Coastguard Worker template, 957*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 958*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_mapping"], 959*da0073e9SAndroid Build Coastguard Worker "full type inference is not yet supported", 960*da0073e9SAndroid Build Coastguard Worker ) 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 963*da0073e9SAndroid Build Coastguard Worker template, 964*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], Dict[str, int]]", 965*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_mapping_and_kwargs"], 966*da0073e9SAndroid Build Coastguard Worker "full type inference is not yet supported", 967*da0073e9SAndroid Build Coastguard Worker ) 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker """ 970*da0073e9SAndroid Build Coastguard Worker Union[int, torch.Tensor] 971*da0073e9SAndroid Build Coastguard Worker """ 972*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 973*da0073e9SAndroid Build Coastguard Worker template, 974*da0073e9SAndroid Build Coastguard Worker "Union[int, torch.Tensor]", 975*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_empty"], 976*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with " "an inner Dict type", 977*da0073e9SAndroid Build Coastguard Worker ) 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 980*da0073e9SAndroid Build Coastguard Worker template, 981*da0073e9SAndroid Build Coastguard Worker "Union[int, 
torch.Tensor]", 982*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_str_tensor"], 983*da0073e9SAndroid Build Coastguard Worker "Expected an Union type annotation with " "an inner Dict type", 984*da0073e9SAndroid Build Coastguard Worker ) 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker # See above--string frontend does not support tuple unpacking 987*da0073e9SAndroid Build Coastguard Worker # self._assert_raises(template, "Union[int, torch.Tensor]", 988*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_tensor"], 989*da0073e9SAndroid Build Coastguard Worker # "foobar") 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker """ 992*da0073e9SAndroid Build Coastguard Worker Union[Dict[str, torch.Tensor], int] 993*da0073e9SAndroid Build Coastguard Worker """ 994*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 995*da0073e9SAndroid Build Coastguard Worker template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"] 996*da0073e9SAndroid Build Coastguard Worker ) 997*da0073e9SAndroid Build Coastguard Worker 998*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 999*da0073e9SAndroid Build Coastguard Worker template, 1000*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1001*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_str_tensor"], 1002*da0073e9SAndroid Build Coastguard Worker ) 1003*da0073e9SAndroid Build Coastguard Worker 1004*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 1005*da0073e9SAndroid Build Coastguard Worker template, 1006*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1007*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_str_int"], 1008*da0073e9SAndroid Build Coastguard Worker "Type annotation was inferred to be " 1009*da0073e9SAndroid Build Coastguard Worker r"`Dict\[str, Tensor\]`, but the type of " 
1010*da0073e9SAndroid Build Coastguard Worker "values given by the dict literal is", 1011*da0073e9SAndroid Build Coastguard Worker ) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker self._assert_raises( 1014*da0073e9SAndroid Build Coastguard Worker template, 1015*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1016*da0073e9SAndroid Build Coastguard Worker lhs["dict_literal_of_mixed"], 1017*da0073e9SAndroid Build Coastguard Worker "Type annotation was inferred to be " 1018*da0073e9SAndroid Build Coastguard Worker r"`Dict\[str, Tensor\]`, but the type of " 1019*da0073e9SAndroid Build Coastguard Worker "values given by the dict literal is", 1020*da0073e9SAndroid Build Coastguard Worker ) 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 1023*da0073e9SAndroid Build Coastguard Worker template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"] 1024*da0073e9SAndroid Build Coastguard Worker ) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 1027*da0073e9SAndroid Build Coastguard Worker template, 1028*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1029*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_iterable"], 1030*da0073e9SAndroid Build Coastguard Worker ) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 1033*da0073e9SAndroid Build Coastguard Worker template, 1034*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1035*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_empty_iterable"], 1036*da0073e9SAndroid Build Coastguard Worker ) 1037*da0073e9SAndroid Build Coastguard Worker 1038*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 1039*da0073e9SAndroid Build Coastguard Worker template, 
1040*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1041*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_mapping"], 1042*da0073e9SAndroid Build Coastguard Worker ) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker self._assert_passes( 1045*da0073e9SAndroid Build Coastguard Worker template, 1046*da0073e9SAndroid Build Coastguard Worker "Union[Dict[str, torch.Tensor], int]", 1047*da0073e9SAndroid Build Coastguard Worker lhs["dict_keyword_with_mapping_and_kwargs"], 1048*da0073e9SAndroid Build Coastguard Worker ) 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker # See above--string frontend does not support tuple unpacking 1051*da0073e9SAndroid Build Coastguard Worker # self._assert_passes(template, 1052*da0073e9SAndroid Build Coastguard Worker # "Union[Dict[str, torch.Tensor], int]", 1053*da0073e9SAndroid Build Coastguard Worker # lhs["dict_keyword_with_internal_aggregate_function"]) 1054*da0073e9SAndroid Build Coastguard Worker # 1055*da0073e9SAndroid Build Coastguard Worker # self._assert_passes(template, 1056*da0073e9SAndroid Build Coastguard Worker # "Union[Dict[str, torch.Tensor], int]", 1057*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_str_tensor"]) 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker # self._assert_raises(template, 1060*da0073e9SAndroid Build Coastguard Worker # "Union[Dict[str, torch.Tensor], int]", 1061*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_str_int"], 1062*da0073e9SAndroid Build Coastguard Worker # "foobar") 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker # self._assert_raises(template, 1065*da0073e9SAndroid Build Coastguard Worker # "Union[Dict[str, torch.Tensor], int]", 1066*da0073e9SAndroid Build Coastguard Worker # lhs["dict_comprehension_of_mixed"], 
1067*da0073e9SAndroid Build Coastguard Worker # "foobar") 1068