# Owner(s): ["oncall: jit"]

import os
import re
import sys
import types
import typing
import typing_extensions
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch
import torch.jit.frontend
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import (
    _tmp_donotuse_dont_inline_everything,
    JitTestCase,
)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestRecursiveScript(JitTestCase):
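    # An attribute initialized to None should script cleanly; TorchScript infers NoneType for it.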
    def test_inferred_nonetype(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = None

            def forward(self):
                assert self.x is None

        m = torch.jit.script(M())
        self.checkModule(M(), ())

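    # Scripted functions stored as module attributes should be callable from forward.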
    def test_script_function_attribute(self):
        @torch.jit.script
        def fn1(x):
            return x + x

        @torch.jit.script
        def fn2(x):
            return x - x

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        fn1_mod = M(fn1)
        fn2_mod = M(fn2)

        self.checkModule(fn1_mod, (torch.randn(2, 2),))
        self.checkModule(fn2_mod, (torch.randn(2, 2),))

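    # Plain Python callables such as torch.sigmoid should also work as function attributes.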
    def test_python_function_attribute(self):
        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        mod = M(torch.sigmoid)

        self.checkModule(mod, (torch.randn(2, 2),))

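    # A function attribute that fails to compile should raise at scripting time and highlight the bad name.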
    def test_failed_function_compilation(self):
        def fn(x):
            return i_dont_exist  # noqa: F821

        class M(torch.nn.Module):
            def __init__(self, fn):
                super().__init__()
                self.fn = fn

            def forward(self, x):
                return self.fn(x)

        m = M(fn)
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to compile", "i_dont_exist"
        ):
            torch.jit.script(m)

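    # Forgetting super().__init__() should produce a "has not been initialized" error.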
    def test_init_error(self):
        class M(nn.Module):
            def __init__(self) -> None:
                self.x = 2

            def forward(self):
                pass

        with self.assertRaisesRegex(RuntimeError, "has not been initialized"):
            torch.jit.script(M())

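    # The training flag is captured when a module is scripted, so scripting before and after eval() differs.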
    def test_script_after_eval(self):
        class M(nn.Module):
            def forward(self):
                if self.training:
                    return 2
                else:
                    return 0

        m = M()
        sm1 = torch.jit.script(m)
        m.eval()
        sm2 = torch.jit.script(m)

        # m is in eval mode, training should be False
        self.assertFalse(m.training)

        # sm1 was created while m had training = True
        self.assertTrue(sm1.training)
        self.assertEqual(sm1.training, sm1._c.getattr("training"))
        self.assertEqual(sm1(), 2)

        # sm2 was created after m was eval'ed
        self.assertFalse(sm2.training)
        self.assertEqual(sm2.training, sm2._c.getattr("training"))
        self.assertEqual(sm2(), 0)

    def test_module_name(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        m = torch.jit.script(MyModule())
        FileCheck().check("MyModule").run(m.graph)

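    # Scripting the same failing function twice should not leak stale frames into the second error message.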
    def test_repeated_error_stack(self):
        def d(x):
            return "a" - 2

        def c(x):
            return d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        try:
            torch.jit.script(a)
        except Exception as e:
            FileCheck().check_count("is being compiled", 2).run(str(e))

        try:
            torch.jit.script(a)
        except Exception as e:
            # Make sure that no entries are left over from the previous failure
            FileCheck().check_count("is being compiled", 2).run(str(e))

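    # torch.jit.Final, typing_extensions.Final, and typing.Final should all be accepted for constants.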
    def test_constants_with_final(self):
        class M1(torch.nn.Module):
            x: torch.jit.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M1(), (torch.randn(2, 2),))

        class M2(torch.nn.Module):
            x: typing_extensions.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M2(), (torch.randn(2, 2),))

        class M3(torch.nn.Module):
            x: typing.Final[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M3(), (torch.randn(2, 2),))

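    # A class marked @torch.jit.ignore cannot be instantiated from scripted code.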
    def test_ignore_class(self):
        @torch.jit.ignore
        class MyScriptClass:
            def unscriptable(self):
                return "a" + 200

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return MyScriptClass()

        with self.assertRaisesRegexWithHighlight(
            torch.jit.frontend.FrontendError,
            "Cannot instantiate class",
            "MyScriptClass",
        ):
            t = torch.jit.script(TestModule())

    def test_method_call(self):
        class M(nn.Module):
            def test(self, x):
                return x

            def forward(self, z):
                y = self.test(z)
                return z + 20 + y

        self.checkModule(M(), (torch.randn(2, 2),))

    def test_module_repr(self):
        class Submodule(nn.Module):
            def forward(self, x):
                return x

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)
                self.sub = Submodule()

            def forward(self, x):
                return self.lin(x) + self.sub(x) + self.conv(x)

        m = torch.jit.script(MyModule())

        with self.capture_stdout() as out:
            print(m)

        f = FileCheck()
        f.check("MyModule")
        f.check("Conv2d")
        f.check("Linear")
        f.check("Submodule")
        f.run(out[0])

        self.assertEqual(m.original_name, "MyModule")

    def test_dir(self):
        def test_module_dir(mod):
            dir_set = dir(mod)
            scripted_mod = torch.jit.script(mod)
            dir_scripted = set(dir(scripted_mod))
            # these attributes are not currently copied over to the scripted module
            ignore_set = [
                "training",
                "__delitem__",
                "__setitem__",
                "clear",
                "items",
                "keys",
                "pop",
                "update",
                "values",
            ]
            for attr in dir_set:
                if attr in ignore_set:
                    continue
                self.assertTrue(attr in dir_scripted, attr)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(10, 10, 3)
                self.lin = nn.Linear(10, 10)

            def forward(self, x):
                return self.lin(x) + self.conv(x)

        test_module_dir(MyModule())

        # test custom __dir__ for containers
        conv = nn.Conv2d(10, 10, 3)
        linear = nn.Linear(10, 10)

        test_module_dir(nn.Sequential(conv, linear))
        test_module_dir(
            nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
        )

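    # Plain Python classes and the free functions they call should be compiled recursively.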
    def test_class_compile(self):
        def other_fn(a: int, b: Tensor) -> Tensor:
            return a * b

        class B:
            def __init__(self, x):
                self.x = 2

            def helper(self, a):
                return self.x + a + other_fn(self.x, a)

        class N(torch.nn.Module):
            def forward(self, x):
                b = B(x)
                return b.helper(x)

        self.checkModule(N(), (torch.randn(2, 2),))

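    # A type error deep in a call chain should report each frame of the chain (a -> b -> c -> d).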
    def test_error_stack(self):
        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        try:
            scripted = torch.jit.script(a)
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("def c(x)")
            checker.check("def b(x)")
            checker.check("def a(x)")
            checker.run(str(e))

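    # The same error-stack reporting should work when the failure is reached through a submodule.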
    def test_error_stack_module(self):
        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        try:
            scripted = torch.jit.script(M())
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("'c' is being compiled since it was called from 'b'")
            checker.check("'b' is being compiled since it was called from")
            checker.run(str(e))

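    # With inlining disabled, calling a Python function from a scripted one should emit prim::CallFunction.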
    @_tmp_donotuse_dont_inline_everything
    def test_script_basic(self):
        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        graph = str(a_script_fn.graph)
        FileCheck().check("prim::CallFunction").run(graph)
        FileCheck().check_not("^a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

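    # Unscriptable statements (here an import) inside a class used by a function should report the call chain.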
    def test_error_stack_class(self):
        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        try:
            torch.jit.script(fn)
        except Exception as e:
            checker = FileCheck()
            checker.check("import statements")
            checker.check("is being compiled since it was called from")
            checker.run(str(e))

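    # Same as above, but also checks that the return annotation (-> X) appears in the error message.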
    def test_error_stack_annotation(self):
        class X:
            def bad_fn(self):
                import pdb  # noqa: F401

        def fn(x) -> X:
            return X(10)

        try:
            torch.jit.script(fn)
        except Exception as e:
            checker = FileCheck()
            checker.check("import statements")
            checker.check("is being compiled since it was called from")
            checker.check("-> X")
            checker.run(str(e))

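    # Methods not reachable from forward (like some_unscriptable_method) should be left uncompiled.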
    def test_module_basic(self):
        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            def some_unscriptable_method(self):
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

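    # @torch.jit.export compiles a method as an extra entry point even though forward never calls it.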
    def test_module_function_export(self):
        class Other(torch.nn.Module):
            __constants__ = ["x"]

            def __init__(self, x):
                super().__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        self.checkModule(M(), (torch.ones(2, 2),))

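    # nn.Sequential and nn.ModuleList, including nested ones, should be iterable from forward.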
    def test_iterable_modules(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sequential = nn.Sequential(
                    Inner(), Inner(), nn.Sequential(Inner(), Inner())
                )
                self.module_list = nn.ModuleList([Inner(), Inner()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))

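    # __prepare_scriptable__ lets a module substitute a replacement for itself at scripting time.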
    def test_prepare_scriptable_basic(self):
        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        t = torch.randn(5, 5)
        m = SeluButReluWhenScripted()
        sm = torch.jit.script(m)
        eager_out = m(t)
        script_out = sm(t)
        self.assertNotEqual(eager_out, script_out)

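    # The __prepare_scriptable__ hook should also apply recursively through containers and shared instances.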
    def test_prepare_scriptable_iterable_modules(self):
        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                shared = SeluButReluWhenScripted()
                self.sequential = nn.Sequential(
                    SeluButReluWhenScripted(),
                    SeluButReluWhenScripted(),
                    nn.Sequential(
                        SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
                    ),
                    shared,
                )
                self.module_list = nn.ModuleList(
                    [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
                )

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        t = torch.randn(5, 5)
        m = M()
        eager_out = m(t.clone())
        sm = torch.jit.script(m)
        script_out = sm(t.clone())
        self.assertNotEqual(eager_out, script_out)

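    # A reference cycle between modules stashed in __dict__ should not recurse forever during scripting.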
    def test_prepare_scriptable_cycle(self):
        t = torch.randn(5, 5)
        c = torch.nn.Module()
        p = torch.nn.Module()
        c.__dict__["_p"] = p
        p.__dict__["_c"] = c

        sm = torch.jit.script(p)

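    # The hook can return a plain function, acting as an escape hatch for otherwise unscriptable objects.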
    def test_prepare_scriptable_escape_hatch(self):
        class NonJitableClass:
            def __call__(self, int1, int2, *args):
                total = int1 + int2
                for arg in args:
                    total += arg
                return total

        obj = NonJitableClass()

        self.assertEqual(obj(1, 2), 3)
        self.assertEqual(obj(1, 2, 3, 4), 10)
        with self.assertRaisesRegex(
            torch.jit.frontend.NotSupportedError,
            expected_regex="can't take variable number of arguments",
        ):
            torch.jit.script(obj)

        def escape_hatch(int1: int, int2: int) -> int:
            return int1 + int2

        class NonJitableClassWithEscapeHatch(NonJitableClass):
            def __prepare_scriptable__(self):
                return escape_hatch

        jit_obj = torch.jit.script(NonJitableClassWithEscapeHatch())

        self.assertEqual(jit_obj(1, 2), 3)
        with self.assertRaisesRegex(
            RuntimeError,
            expected_regex=re.escape(
                "expected at most 2 argument(s) but received 4 argument(s)"
            ),
        ):
            jit_obj(1, 2, 3, 4)

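    # Attributes of many types, including script classes, should be picked up during recursive scripting.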
    def test_attributes(self):
        @torch.jit.script
        class Inner2:
            def __init__(self) -> None:
                self.b = "a string"

        @torch.jit.script
        class Foo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

        @torch.jit.script
        class SFoo:
            def __init__(self) -> None:
                self.a = 4
                self.inner = Inner2()

            def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
                a, inner = obj
                self.a = a
                self.inner = inner

            def __getstate__(self):
                return (self.a, self.inner)

        untyped_values = (
            ("my_dict", {"I": "am", "a test": "test"}),
            ("my_float", 2.3),
            ("my_int", 99),
            ("my_bool", False),
            ("my_tuple", (1, 2, 3, 4)),
            ("my_list", [(1, 2), (3, 4)]),
            # ('my_tensor', torch.randn(2, 2)),
            ("my_int_list", [1, 2, 3, 4]),
            # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
            ("my_bool_list", [True, True, False, True]),
            ("my_float_list", [1.0, 2.0, 3.0, 4.0]),
            ("my_str_list", ["hello", "bye"]),
        )
        typed_values = (
            ("my_empty_list", []),
            ("my_empty_dict", {}),
            ("my_none", None),
            ("my_object", Foo()),
            ("my_object2", SFoo()),
        )

        class M(torch.nn.Module):
            # TODO: re-enable this once this test is in a Python 3-only syntax
            # file
            # my_empty_list : List[int]
            # my_empty_dict : Dict[str, int]
            # my_none : Optional[int]

            def forward(self, x):
                return (
                    self.my_dict,
                    self.my_float,
                    self.my_int,
                    self.my_bool,
                    # self.my_tensor,
                    self.my_int_list,
                    # self.my_tensor_list,
                    self.my_bool_list,
                    self.my_float_list,
                    self.my_str_list,
                    self.my_empty_list,
                    self.my_empty_dict,
                    self.my_none,
                    self.my_object.a,
                    self.my_object.inner.b,
                    self.my_object2.a,
                    self.my_object2.inner.b,
                )

        # TODO: as a followup, fix this test
        # We can't define class attributes like we should be doing:
        #   class M(torch.nn.Module):
        #       my_empty_list : List[int]
        #       my_empty_dict : Dict[str, int]
        #       my_none : Optional[int]
        #       my_out_of_line_attribute: List[int] = [1, 2, 3]
        # since there's no string frontend for Python classes (so the `define`
        # trick doesn't work).
        M.__annotations__ = {
            "my_empty_list": List[int],
            "my_empty_dict": Dict[str, int],
            "my_none": Optional[int],
            "my_object": Foo,
            "my_object2": SFoo,
        }

        m = M()
        for name, value in untyped_values + typed_values:
            setattr(m, name, value)

        self.checkModule(m, (torch.randn(5, 5),))

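    # Function attributes on a nested submodule should be handled during recursive scripting.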
    def test_function_attribute_in_submodule(self):
        class N(nn.Module):
            def __init__(self, norm):
                super().__init__()
                self.activation = torch.nn.functional.relu
                self.norm = norm

            def forward(self, src):
                output = src
                output = self.norm(output)
                return output

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                encoder_norm = nn.ReLU()
                self.encoder = N(encoder_norm)

            def forward(self, x):
                return self.encoder(x)

        m = M()
        self.checkModule(m, (torch.randn(5, 5),))

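    # A traced module held in an nn.ModuleList should work as a submodule of a scripted module.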
    def test_inner_traced_module(self):
        class Dummy(nn.Module):
            def forward(self, x):
                return x

        class Model(nn.Module):
            def __init__(self, dummies):
                super().__init__()
                self._dummies = dummies

            def forward(self, x):
                out = []
                for dummy in self._dummies:
                    out.append(dummy(x))
                return out

        dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
        dummies = nn.ModuleList([dummy])
        model = Model(dummies)
        self.checkModule(model, (torch.rand(5, 5),))

    def test_script_loaded_module(self):
        """
        Test that we can hold a loaded ScriptModule as a submodule.
        """

        class Dummy(nn.Module):
            def forward(self, x):
                return x

        dummy = torch.jit.script(Dummy())
        dummy = self.getExportImportCopy(dummy)

        class ContainsLoaded(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.encoder = dummy

            def forward(self, input):
                return self.encoder(input)

        self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))

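    # A submodule attribute reassigned to None should still script, taking the None branch.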
    def test_optional_module(self):
        class Dummy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = nn.Linear(2, 2)

            def forward(self, x):
                if self.foo is not None:
                    return self.foo(x)
                return x

        mod = Dummy()
        self.checkModule(mod, (torch.rand(2, 2),))
        mod.foo = None
        self.checkModule(mod, (torch.rand(2, 2),))

    def test_override_instance_method_ignore(self):
        class M(torch.nn.Module):
            @torch.jit.ignore
            def i_am_ignored(self):
                return "old"

        m = M()

        # Override the ignored method by binding a new method to this instance.
        @torch.jit.ignore
        def i_am_ignored(self):
            return "new"

        m.i_am_ignored = types.MethodType(i_am_ignored, m)
        self.assertEqual(m.i_am_ignored(), "new")

        # ScriptModule should correctly reflect the override.
        s = torch.jit.script(m)
        self.assertEqual(s.i_am_ignored(), "new")
801