xref: /aosp_15_r20/external/pytorch/test/jit/test_typing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from collections import namedtuple
6from typing import Dict, List, NamedTuple, Tuple
7
8import torch
9from torch.testing._internal.common_utils import IS_WINDOWS
10from torch.testing._internal.jit_utils import JitTestCase, make_global
11
12
# Make the helper files in test/ importable.
# pytorch_test_dir is the parent of the directory that contains this file.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Refuse direct execution: the message points the user at test/test_jit.py.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
23
24
25class TestTyping(JitTestCase):
    def test_dict_in_not_in(self):
        # Exercises `in` / `not in` on dicts with str keys and with Tensor
        # keys, plus the evaluation order of the two operands of `in`.
        def test_in_dict(x):
            # type: (Dict[str, int]) -> bool
            return "hi" in x

        # One dict that contains the probed key and one that does not.
        self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},))
        self.checkScript(test_in_dict, ({"bye": 3},))

        # Check evaluation order
        @torch.jit.script
        def a():
            print("a")
            return 3

        @torch.jit.script
        def b():
            print("b")
            return {3: 2, 4: 1}

        @torch.jit.script
        def fn():
            return a() in b()

        # a() returns 3, which is a key of the dict returned by b().
        with self.capture_stdout() as captured:
            self.assertTrue(fn())
        if not IS_WINDOWS:
            # no stdout capturing on windows
            # "a" printed before "b" means the left operand ran first.
            self.assertEqual(captured[0], "a\nb\n")

        def test_not_in_dict(a):
            # type: (Dict[str, int]) -> bool
            if "hello" not in a:
                return False
            else:
                return True

        self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},))
        self.checkScript(test_not_in_dict, ({"world": 2},))

        def test_dict_tensor_key(a, t):
            # type: (Dict[Tensor, int], Tensor) -> bool
            if t in a:
                return True
            else:
                return False

        # The dict is keyed by the tensor objects inp1 and inp2. It is probed
        # with freshly created tensors (values 4 and 3) and with the key
        # objects themselves.
        inp1 = torch.tensor(3)
        inp2 = torch.tensor(5)
        dict_a = {inp1: 1, inp2: 3}
        self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4)))
        self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3)))
        self.checkScript(test_dict_tensor_key, (dict_a, inp1))
        self.checkScript(test_dict_tensor_key, (dict_a, inp2))
79
    def test_list_type_refinement_annotation_element_mismatch(self):
        # A list literal annotated List[int] but containing a str element must
        # make torch.jit.script raise RuntimeError with the message below.
        def fn():
            l: List[int] = [1, 2, "foo", 3]
            return l

        with self.assertRaisesRegex(
            RuntimeError,
            "List type annotation"
            r" `List\[int\]` did not match the "
            "types of the given list elements",
        ):
            torch.jit.script(fn)
92
    def test_dict_type_refinement_annotation_key_mismatch(self):
        # The dict is built with dict(zip(...)) from a key list that mixes
        # int and str; scripting must raise RuntimeError about non-homogeneous
        # keys.
        def fn():
            l1 = [1, 2, "foo", 3]
            l2 = ["foo", "bar", "baz", "qux"]
            d: Dict[int, str] = dict(zip(l1, l2))
            return d

        with self.assertRaisesRegex(
            RuntimeError,
            "Dicts may only "
            "contain homogeneous keys, but the "
            "type of the first generated key "
            r"was Union\[int, str\]",
        ):
            torch.jit.script(fn)
108
    def test_dict_type_refinement_annotation_value_mismatch(self):
        # Same construction as the key-mismatch test, but here the value list
        # mixes int and str while the annotation says Dict[str, int].
        def fn():
            l1 = ["foo", "bar", "baz", "qux"]
            l2 = [1, 2, "foo", 3]
            d: Dict[str, int] = dict(zip(l1, l2))
            return d

        with self.assertRaisesRegex(
            RuntimeError,
            "Dict type annotation"
            r" `Dict\[str, int\]` did not match"
            " the type of an actual value type"
            r" `Union\[int, str\]`",
        ):
            torch.jit.script(fn)
124
    def test_dict_invalid_annotations(self):
        # Dict annotations that use torch.jit.ScriptModule as value type, key
        # type, or both must each make torch.jit.script raise ValueError
        # matching "Unknown type annotation".

        # Check for invalid value type annotation
        def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
            return

        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_value_type)

        # Check for invalid key type annotation
        def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
            return

        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_key_type)

        # Check for invalid key and value type annotation
        def wrong_key_value_type(
            dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
        ):
            return

        with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
            torch.jit.script(wrong_key_value_type)
148
149    def test_tuple_specialization(self):
150        @torch.jit.script
151        def f(t, s):
152            # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
153            x, t2 = t
154            _, y = t2
155            return x + y
156
157        t = (
158            torch.randn(2, 2),
159            (1, torch.randn(2, 2)),
160        )
161        f(t, "hi")
162        graph = f.graph_for(t, "hi")
163        input_types = list(next(graph.inputs()).type().elements())
164        w = input_types[0]
165        self.assertEqual(input_types[0].kind(), "TensorType")
166        self.assertEqual(input_types[1].elements()[1].kind(), "TensorType")
167
    def test_tuple_io(self):
        # A function that takes a 2-tuple of tensors and returns the two
        # elements swapped, run through checkScript.
        def stuff(x):
            # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
            a, b = x
            return b, a

        a = (torch.rand(3), torch.rand(3))
        self.checkScript(stuff, (a,))
176
    def test_tuple_keyword(self):
        # Three uses of the `tuple(...)` builtin: one accepted form and two
        # rejected ones.

        # tuple() applied to a tuple literal goes through checkScript.
        def bar():
            f = tuple((1, 2))  # noqa: C409
            return f

        self.checkScript(bar, ())

        # tuple() with two positional arguments must raise an Exception whose
        # message matches "1 argument".
        def foo():
            return tuple(1, 2)

        self.checkScriptRaisesRegex(foo, (), Exception, "1 argument")

        # tuple() applied to a list literal must fail at script time with the
        # "cannot statically infer the expected" message.
        def cant_infer_size():
            return tuple([1, 2, 3])  # noqa: C409

        with self.assertRaisesRegex(Exception, "cannot statically infer the expected"):
            torch.jit.script(cant_infer_size)
194
    def test_tuple_create_return(self):
        # Builds a tuple of two tensors inside the function and returns it.
        def stuff2(x):
            # type: (int) -> Tuple[Tensor, Tensor]
            a = (torch.ones(x), torch.zeros(x))
            return a

        self.checkScript(stuff2, (3,))
202
    def test_list_io(self):
        # A List[int] argument is used as the size passed to torch.ones and is
        # also returned unchanged as the second tuple element.
        def stuff3(x):
            # type: (List[int]) -> Tuple[Tensor, List[int]]
            return torch.ones(x), x

        self.checkScript(stuff3, ([3, 2],))
209
    def test_bool_list_io(self):
        # List[bool] values passed in, created as literals, and nested one
        # level deep must all come back from the scripted function holding
        # Python `bool` elements.
        @torch.jit.script
        def stuff4(x):
            # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
            return x, [True, False], [[True]]

        li_1, li_2, li_3 = stuff4([True])
        # Unwrap the nested list so all three can be checked the same way.
        li_3 = li_3[0]
        for li in [li_1, li_2, li_3]:
            self.assertTrue(type(li[0]) == bool)
220
    def test_nested_list(self):
        # Indexes into a List[List[int]] carried inside a tuple argument.
        def foo(z):
            # type: (Tuple[int, List[List[int]]]) -> int
            x, y = z
            return y[0][1]

        self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
228
    def test_list_sum(self):
        # The builtin sum() over lists annotated List[int], List[float] and
        # List[bool].
        def fn(x: List[int]) -> int:
            return sum(x)

        def fn1(x: List[float]):
            return sum(x)

        def fn2(x: List[bool]):
            return sum(x)

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn1, ([1.0, 2.0, 3.0],))
        # Mixed int/float literals are passed to the List[float] function.
        self.checkScript(fn1, ([1, 2.8, 3],))
        self.checkScript(fn2, ([True, False, False],))
        self.checkScript(fn2, ([False, False, False],))
        # Integer 0/1 values are passed to the List[bool] function.
        self.checkScript(fn2, ([0, 1, 1, 0],))
245
    def test_list_unification(self):
        # List literals that mix None with ints (fn) or with tensors (fn2).
        def fn():
            return [1, None, 2]

        def fn2(x):
            return [torch.ones(2, 2), None, x]

        self.checkScript(fn, [])
        self.checkScript(fn2, (torch.ones(2, 2),))
255
256    # to avoid defining sum_list in multiple tests
257    def get_sum_list_fn(self):
258        def sum_list(a):
259            # type: (List[int]) -> int
260            sum = 0
261            for i in a:
262                sum += i
263
264            return sum
265
266        return sum_list
267
268    def test_sum_list_diff_elms(self):
269        self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
270
271    def test_sum_list_empty(self):
272        self.checkScript(self.get_sum_list_fn(), ([],))
273
274    def test_sum_list_one(self):
275        self.checkScript(self.get_sum_list_fn(), ([1],))
276
    def test_sum_list_literal(self):
        # Same accumulation loop as get_sum_list_fn, but iterating directly
        # over a list literal instead of an argument.
        def sum_list():
            # type: () -> int
            sum = 0
            for i in [1, 2, 3, 4, 5]:
                sum += i

            return sum

        self.checkScript(sum_list, ())
287
    def test_sum_list_wrong_type(self):
        # Iterating over an argument typed `int` must raise RuntimeError
        # matching "'int' object is not iterable". Both the scripting
        # (decorator) and the call sit inside the assertRaisesRegex block, so
        # the error may come from either step.
        with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):

            @torch.jit.script
            def sum_list(a):
                # type: (int) -> int
                sum = 0
                for i in a:  # noqa: T484
                    sum += i

                return sum

            sum_list(1)
301
    def test_list_iterables(self):
        # A `for` loop whose iterable is a bare comma-separated pair of list
        # literals must be rejected by CompilationUnit with the
        # "List of iterables is not supported currently" error.
        with self.assertRaisesRegex(
            RuntimeError, "List of iterables is not supported currently"
        ):
            cu = torch.jit.CompilationUnit(
                """
            def list_iterables(x):
                for i, j in [2, 3, 4], [5, 6, 7]:
                    x += i
                    x += j
                return x
            """
            )
315
    def test_for_in_string(self):
        # `for` loops over the characters of a str and over a List[str].

        # Builds the reversed string by prepending each character.
        def test_strings(x):
            # type: (str) -> str
            reverse = ""
            for c in x:
                reverse = c + reverse
            return reverse

        self.checkScript(test_strings, ("hello",))
        self.checkScript(test_strings, ("",))

        # Concatenates every element of a list of strings.
        def test_list_strings(x):
            # type: (List[str]) -> str
            result = ""
            for sub_str in x:
                result += sub_str
            return result

        self.checkScript(test_list_strings, (["hello", "world"],))
        self.checkScript(test_list_strings, (["hello", " ", "world", ""],))
336
337    def test_for_in_dict(self):
338        def test_dicts(x):
339            # type: (Dict[str, int]) -> int
340            sum = 0
341            for key in x:
342                sum += x[key]
343            return sum
344
345        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
346
347        def test_dict_keys_values(x):
348            # type: (Dict[str, int]) -> Tuple[str, int]
349            key_str = ""
350            sum = 0
351            for key in x.keys():
352                key_str += key
353            for val in x.values():
354                sum += val
355            return key_str, sum
356
357        self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},))
358
    def test_for_tuple_unpack(self):
        # Unpacking in the target of a `for` loop: flat unpacking of inner
        # lists, and nested unpacking over zip(..., enumerate(...), ...).
        def for_tuple_unpack(x, y):
            for i, j in [[3, 4], [5, 6], [7, 8]]:
                x += i
                y += j
            return x, y

        self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5)))

        def nested_tuple_unpack(x, y):
            # type: (List[int], List[int]) -> int
            sum = 0
            for i, (j, k), v in zip(x, enumerate(x), y):
                sum += i + j + k + v
            return sum

        self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6]))
376
    def test_dict_comprehension(self):
        # A dict comprehension mapping 0..3 to chr(i + 65), without annotation.
        def fn():
            return {i: chr(i + 65) for i in range(4)}

        self.checkScript(fn, ())
382
    def test_dict_comprehension_with_type_annotation(self):
        # A dict comprehension assigned to a variable annotated Dict[int, str]
        # goes through checkScript; the same comprehension annotated
        # Tuple[int, str] is expected to fail at script time.
        def fn():
            d: Dict[int, str] = {i: chr(i + 65) for i in range(4)}
            return d

        self.checkScript(fn, ())

        # The outer context expects a RuntimeError (empty pattern, so any
        # message matches); the inner one expects an AssertionError with the
        # message below.
        # NOTE(review): as written the test can only pass when a RuntimeError
        # escapes the inner context, in which case the inner message pattern
        # is never checked. The inner pattern also contains unescaped
        # `[` ... `]`, which the regex engine reads as a character class.
        # Confirm the intended exception type and message.
        with self.assertRaisesRegex(RuntimeError, ""):
            with self.assertRaisesRegex(
                AssertionError,
                "Expected Dict "
                "type annotation for dict "
                "comprehension, found "
                "Tuple[int, str]",
            ):

                @torch.jit.script
                def fn():
                    d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)}
                    return d
403
    def test_dict_comprehension_scope(self):
        # Scoping of dict comprehensions: the comprehension can read a local
        # defined before it, but its loop variable is not visible afterwards.
        def comprehension_can_access_outer_scope_variables():
            lst = ["foo", "bar", "baz"]
            return {l: len(l) for l in lst}

        self.checkScript(comprehension_can_access_outer_scope_variables, ())

        # Using the comprehension variable `i` after the comprehension must
        # fail at script time with "undefined value i".
        with self.assertRaisesRegex(RuntimeError, "undefined value i"):

            @torch.jit.script
            def outer_scope_cannot_access_comprehension_variables():
                d = {i: chr(i + 65) for i in range(4)}
                i = i + 1  # noqa: F821
417
418    def test_for_tuple_assign(self):
419        def test_simple_assign(x):
420            # type: (Tuple[int, float]) -> float
421            sum = 0.0
422            for a in x:
423                sum += float(a)
424            return sum
425
426        self.checkScript(test_simple_assign, ((1, 2.5),))
427
428        def test_tuple_assign(x):
429            # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int
430            sum = 0
431            for a in x:
432                sum += a[0]
433                sum += a[1]
434            return sum
435
436        self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),))
437
438        def test_single_starred_lhs(self):
439            with self.assertRaisesRegex(
440                RuntimeError,
441                "A Starred expression may only appear on the lhs within the presence"
442                " of another non-starred expression",
443            ):
444                cu = torch.jit.CompilationUnit(
445                    """
446                def single_starred_lhs(x):
447                    a = (x, x, x)
448                    *b, = a
449                    return b
450                """
451                )
452
    def test_singleton_tuple_unpack(self):
        # Unpacking a one-element tuple into a one-element tuple target.
        def foo(a):
            (b,) = (a,)
            return b + 1

        self.checkScript(foo, (torch.rand(3),))
459
    def test_tuple_assignments(self):
        # Tuple-destructuring assignment forms: nested targets, subscript
        # targets, starred targets, and a rejected augmented assignment into a
        # tuple element.
        def var_tuple_assign(x, y):
            # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
            (a, b), c = x, y
            return a + b + c

        tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4))
        self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4)))

        def nested_tuple_assign(x, y, z):
            # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int
            a, (b, (c, d)), (e, f) = x, y, z
            return a + b + c + d + e + f

        # The outermost parentheses are redundant: the inputs are the three
        # arguments 1, (2, (3, 4)) and (5, 6).
        self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6))))

        def subscript_tuple_assign(a, x, i):
            # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int]
            a[i], (x[i], b) = 1, (2, 3)
            return a[i] + 1, x + 5, b

        self.checkScript(
            subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)
        )

        def star_tuple_assign():
            # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
            a, (b, *c), *d = 1, (2, 3, 4), 5, 6
            return a, b, c, d

        self.checkScript(star_tuple_assign, ())

        def subscript_tuple_augmented_assign(a):
            # type: (Tuple[int, int]) -> Tuple[int, int]
            a[0] += 1
            return a

        # `a[0] += 1` on a tuple must fail at script time.
        with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"):
            scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
499
    def test_multiple_assign(self):
        # Chained assignment (`a = b, c = d, f = ...`), chained assignment of
        # the result of an in-place op, and a swap-style tuple assignment.
        def test():
            a = b, c = d, f = (1, 1)

            # side effect
            ten = torch.tensor(1)
            ten1 = ten2 = ten.add_(1)

            # ordering
            x = 1
            y = 3
            x, y = y, x + y

            return a, b, c, d, f, ten, ten1, ten2, x, y

        self.checkScript(test, ())
516
    def test_opt_opt_refinement(self):
        # Only checks that the function scripts: `opt` is assigned None on two
        # branches and an int on the third, with a declared return type of
        # Optional[int]. The scripted function is never called.
        @torch.jit.script
        def test_unify(weight, bias):
            # type: (Optional[int], Optional[int]) -> Optional[int]
            if weight is not None:
                opt = None
            else:
                if bias is not None:
                    opt = 1
                else:
                    opt = None

            return opt
530
    def test_optional_refinement(self):
        # An Optional[int] argument reassigned to an int inside an
        # `if x is None` branch is then used in `x + 1` with an int return
        # type.
        @torch.jit.script
        def test_if_none_assignment(x):
            # type: (Optional[int]) -> int
            if x is None:
                x = 1
            return x + 1

        self.assertEqual(test_if_none_assignment(1), 2)
540
    def test_optional_conversion(self):
        # Conversions into Optional types: passing an int where Optional[int]
        # is expected, unifying None and int into Optional[int], and passing a
        # tuple of floats where an optional (broadcasting) list is expected.
        @torch.jit.script
        def other_fn(x=None):
            # type: (Optional[int]) -> int
            return torch.jit._unwrap_optional(x)

        @torch.jit.script
        def fn(x):
            # type: (int) -> int
            return other_fn(x)

        self.assertEqual(fn(2), 2)

        @torch.jit.script
        def unify_to_optional(x):
            # type: (bool) -> Optional[int]
            if x:
                a = None
            else:
                a = 2
            return a

        self.assertEqual(unify_to_optional(True), None)
        self.assertEqual(unify_to_optional(False), 2)

        @torch.jit.script
        def opt_list(x):
            # type: (Optional[List[float]]) -> int
            return 2

        @torch.jit.script
        def broadcast_opt_list(x):
            # type: (Optional[BroadcastingList2[float]]) -> int
            return 2

        @torch.jit.script
        def opt_list_tuple_caller(x):
            # type: (Tuple[float, float]) -> int
            return opt_list(x) + broadcast_opt_list(x)

        # Each callee returns 2, so the caller is expected to return 4.
        self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4)
582
    def test_optional_tuple(self):
        # An Optional[Tuple[int, int]] argument with default None: checked
        # once with an explicit tuple and once with no argument at all.
        def fn(x=None):
            # type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
            if x is None:
                new_x = (1, 2)
            else:
                new_x = x
            return new_x

        self.checkScript(fn, ((3, 4),))
        self.checkScript(fn, ())
594
    def test_namedtuple_redefine(self):
        # Two namedtuple classes share the type name "GoogLeNetOutputs" but
        # have different fields; scripting a function that refers to both must
        # raise RuntimeError matching "redefine". The names are bound as
        # module globals so the type comment below can refer to them.
        global _1, _2
        _1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
        _2 = namedtuple("GoogLeNetOutputs", ["different"])

        with self.assertRaisesRegex(RuntimeError, r"redefine"):

            @torch.jit.script
            def foo(x, y):
                # type: (_1, _2) -> _1
                return x
606
    def test_namedtuple_py2(self):
        # A collections.namedtuple referenced through a `# type:` comment is
        # passed through a scripted identity function; the returned value's
        # fields must equal the inputs.
        global _GoogLeNetOutputs  # see [local resolution in python]
        _GoogLeNetOutputs = namedtuple(
            "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
        )

        @torch.jit.script
        def foo(x):
            # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
            return x

        vals = torch.rand(3), torch.rand(4), torch.rand(5)
        out = foo(
            _GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])
        )
        self.assertEqual(out.logits, vals[0])
        self.assertEqual(out.aux_logits2, vals[1])
        self.assertEqual(out.aux_logits1, vals[2])
625
    def test_namedtuple_good_error(self):
        # Calling the scripted function with a namedtuple whose fields hold
        # strings must raise a RuntimeError whose message spells out the
        # NamedTuple field names (pattern below).
        global _GoogLeNetOutputs  # see [local resolution in python]
        _GoogLeNetOutputs = namedtuple(
            "GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
        )

        @torch.jit.script
        def foo(x):
            # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
            return x

        with self.assertRaisesRegex(
            RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)"
        ):
            out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
641
    def test_namedtuple_error_source_attribution(self):
        # A NamedTuple with a field annotated by an unresolvable forward
        # reference must make scripting raise a ValueError whose message
        # includes a source location.
        class _NamedTupleBadMemberType(NamedTuple):
            f1: torch.Tensor
            f2: "ABadForwardRefType"  # noqa: F821

        make_global(_NamedTupleBadMemberType)  # see [local resolution in python]

        def fn(x: _NamedTupleBadMemberType) -> torch.Tensor:
            return x.f1.relu()

        # assert that this has a location associated with the error.
        # note the " +" is regex (i.e. "at least one space")
        with self.assertRaisesRegex(ValueError, "at +File"):
            torch.jit.script(fn)
656
    def test_inherited_annotations_python_310(self):
        # See #104484
        # In python >=3.10, inspect.get_annotations doesn't always return the same values.
        # Sometimes it will show all annotations; other times it will show only annotations
        # that show in that class, not classes it inherits from.
        #
        # The test has no assertions: it only requires that scripting two
        # instances of the most-derived module succeeds. The `state`
        # annotation is declared only on the base class.
        class BaseModule(torch.nn.Module):
            state: List[int]

            def forward(self, x):
                pass

        def do_something_with_list(x: List[int]):
            if x:
                return x[-1]
            return 5

        class Submodule(BaseModule):
            def __init__(self, self_x_value):
                super().__init__()
                self.x = self_x_value
                # Empty list; its element type is only given by the
                # annotation on BaseModule.
                self.state = []

            def forward(self, x):
                return self.x + x + do_something_with_list(self.state)

        class LowestModule(Submodule):
            def __init__(self) -> None:
                super().__init__(123)

        mod = LowestModule()
        mod2 = LowestModule()
        mod_s = torch.jit.script(mod)
        mod2_s = torch.jit.script(mod2)
690