# xref: /aosp_15_r20/external/pytorch/test/jit/test_module_containers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from collections import OrderedDict
6from typing import Any, List, Tuple
7
8import torch
9import torch.nn as nn
10from torch.testing._internal.jit_utils import JitTestCase
11
12
# The directory two levels above this file (test/) holds shared helper
# modules; append it to the import path so the tests can import them.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module only defines test cases; refuse direct execution and point the
# user at the test/test_jit.py driver instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:"
        "\n\n\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
23
24
class TestModuleContainers(JitTestCase):
    """TorchScript tests for nn.Module containers.

    Exercises nn.Sequential, nn.ModuleList, nn.ModuleDict, nn.ParameterList and
    nn.ParameterDict under scripting: iteration, indexing, the special methods
    (__len__, __getitem__, __iter__, __contains__, keys/items/values), interface
    annotations that allow non-literal indexing, and the errors raised for
    unsupported usage. Also covers scripted properties and module
    reconstruction via ``_reconstruct``.
    """

    def test_sequential_intermediary_types(self):
        """A Sequential whose stages return different types (Tensor, then dict)."""

        class A(torch.nn.Module):
            def forward(self, x):
                return x + 3

        class B(torch.nn.Module):
            def forward(self, x):
                # Changes the value flowing through the Sequential from a
                # Tensor to a dict.
                return {"1": x}

        class C(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.nn.Sequential(A(), B())

            def forward(self, x):
                return self.foo(x)

        self.checkModule(C(), (torch.tensor(1),))

    def test_moduledict(self):
        """Iterate a ModuleDict via __iter__, items(), values(), keys() and zip/enumerate."""

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class Inner2(torch.nn.Module):
            def forward(self, x):
                return x * 2

        class Inner3(torch.nn.Module):
            def forward(self, x):
                return (x - 4) * 3

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                modules = OrderedDict(
                    [
                        ("one", Inner()),
                        ("two", Inner2()),
                        ("three", Inner3()),
                    ]
                )
                self.moduledict = nn.ModuleDict(modules)

            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                for name in self.moduledict:
                    names.append(name)

                for name, mod in self.moduledict.items():
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for mod in self.moduledict.values():
                    x = mod(x)
                    values.append(x)

                for key in self.moduledict.keys():
                    names.append(key)

                return x, names

        class M2(M):
            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                x2 = x
                iter = 0
                for name in self.moduledict:
                    names.append(name)

                for i, (name, mod) in enumerate(self.moduledict.items()):
                    iter += i
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for i, mod in enumerate(self.moduledict.values()):
                    iter += i
                    x = mod(x)
                    values.append(x)

                for i, key in enumerate(self.moduledict.keys()):
                    iter += i
                    names.append(key)

                # NOTE(review): both loop targets are named `mod` (the second
                # binding wins) and `i` is the value left over from the loop
                # above. This looks deliberate, as a stress case for zip over
                # module containers; confirm before "cleaning it up".
                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
                    iter += i
                    x2 = mod(mod(x2))

                return x, x2, names, iter

        for name in ["", "one", "two", "three"]:
            inp = torch.tensor(1)
            self.checkModule(M(), (inp, name))
            self.checkModule(M2(), (inp, name))

    def test_custom_container_forward(self):
        """Subclasses of Sequential/ModuleList/ModuleDict may define their own forward."""

        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class CustomSequential(nn.Sequential):
            def __init__(self) -> None:
                super().__init__(nn.ReLU(), Inner())

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomSequential(), (torch.tensor(0.5),))

        class CustomModuleList(nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([nn.ReLU(), Inner()])

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomModuleList(), (torch.tensor(0.5),))

        class CustomModuleDict(nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__(
                    OrderedDict(
                        [
                            ("one", Inner()),
                            ("two", nn.ReLU()),
                            ("three", Inner()),
                        ]
                    )
                )

            def forward(self, x):
                x = x + 3
                names = torch.jit.annotate(List[str], [])
                for name, mod in self.items():
                    x = mod(x)
                    names.append(name)
                return names, x - 5

        self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))

    def test_script_module_list_sequential(self):
        """A Sequential held by a legacy ScriptModule can be iterated and exported."""

        class M(torch.jit.ScriptModule):
            def __init__(self, mod_list):
                super().__init__()
                self.mods = mod_list

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        with torch.jit.optimized_execution(False):
            m = M(nn.Sequential(nn.ReLU()))
            self.assertExportImportModule(m, (torch.randn(2, 2),))

    def test_script_modulelist_index(self):
        """Literal (including negative) ModuleList indices work; others raise."""

        class Sub(torch.nn.Module):
            def __init__(self, i):
                super().__init__()
                self.i = i

            def forward(self, thing):
                return thing - self.i

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4].forward(v)
                v = self.mods[-1].forward(v)
                v = self.mods[-9].forward(v)
                return v

        x = torch.tensor(1)
        self.checkModule(M(), (x,))

        class MForward(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4](v)
                v = self.mods[-1](v)
                v = self.mods[-9](v)
                return v

        self.checkModule(MForward(), (torch.tensor(1),))

        # -11 is out of range for a 10-element ModuleList.
        class M2(M):
            def forward(self, v):
                return self.mods[-11].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Index -11 out of range", "self.mods[-11]"
        ):
            torch.jit.script(M2())

        # Indexing with a variable (not a literal) is rejected.
        class M3(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegexWithHighlight(
            Exception, "Enumeration is supported", "self.mods[i]"
        ):
            torch.jit.script(M3())

        # M4 repeats M3's body on purpose, so this second assertion (checking a
        # different part of the same error message) scripts a fresh class
        # rather than re-scripting M3 after its failed compilation.
        class M4(M):
            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegex(Exception, "will fail because i is not a literal"):
            torch.jit.script(M4())

    def test_module_interface_special_methods(self):
        """Container special methods still work when the container also derives from another Module."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.Sequential.__init__(self, modules)

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])
                self.sequential = CustomSequential(self.submod)
                self.moduledict = CustomModuleDict({"submod": self.submod})

            def forward(self, inputs):
                assert (
                    self.modulelist[0] is self.submod
                ), "__getitem__ failing for ModuleList"
                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                for module in self.modulelist:
                    assert module is self.submod, "__iter__ failing for ModuleList"

                assert (
                    self.sequential[0] is self.submod
                ), "__getitem__ failing for Sequential"
                assert len(self.sequential) == 1, "__len__ failing for Sequential"
                for module in self.sequential:
                    assert module is self.submod, "__iter__ failing for Sequential"

                assert (
                    self.moduledict["submod"] is self.submod
                ), "__getitem__ failing for ModuleDict"
                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"

                # note: unable to index moduledict with a string variable currently
                i = 0
                for key in self.moduledict:
                    i += 1
                assert i == len(self.moduledict), "iteration failing for ModuleDict"

                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"

                for key in self.moduledict.keys():
                    assert key == "submod", "keys() fails for ModuleDict"

                for item in self.moduledict.items():
                    assert item[0] == "submod", "items() fails for ModuleDict"
                    assert item[1] is self.submod, "items() fails for ModuleDict"

                for value in self.moduledict.values():
                    assert value is self.submod, "values() fails for ModuleDict"

                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_special_method_with_override(self):
        """A user-defined __len__ on a container subclass is honored when scripting."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

            def __len__(self):
                # this is arbitrary, just to check that the overridden py __len__ from
                # CustomModuleList takes precedence over the automatically generated
                # __len__ added by the jit compiler
                return 2

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])

            def forward(self, inputs):
                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
        # Scripting the module directly must also succeed; the result is unused.
        torch.jit.script(m)

    def test_moduledict_getitem(self):
        """Literal-key ModuleDict lookup returns the very same submodule objects."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.jit.script(torch.nn.ReLU())
                self.tanh = torch.jit.script(torch.nn.Tanh())
                self.moduledict = torch.nn.ModuleDict(
                    {"relu": self.relu, "tanh": self.tanh}
                )

            def forward(self, input):
                assert self.moduledict["relu"] is self.relu
                assert self.moduledict["tanh"] is self.tanh
                return input

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_moduledict_keyerror(self):
        """Missing literal keys and non-literal keys produce distinct compile errors."""

        class BadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                assert self.moduledict["blah"] == "blah", "this is a keyerror"

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
        ):
            b = BadModule()
            torch.jit.script(b)

        class AnotherBadModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})

            def forward(self, input):
                idx = "blah"
                assert self.moduledict[idx] == "blah", "this is a string literal error"

        # The regex escapes the brackets/parentheses of the example in the message.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Unable to extract string literal index. "
            "ModuleDict indexing is only supported with string literals. "
            "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
            "because i is not a literal.",
            "self.moduledict[idx]",
        ):
            b = AnotherBadModule()
            torch.jit.script(b)

    def test_normal_list_attribute_with_modules_error(self):
        """
        Test that an attempt to script a module with a regular list attribute
        containing other modules fails with a relevant error message.
        """

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.nn.ReLU(), torch.nn.ReLU()]

            def forward(self):
                return len(self.a)

        error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module"
        with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"):
            torch.jit.script(Mod())

    def test_empty_dict_override_contains(self):
        """__contains__ on an empty ModuleDict subclass evaluates to False."""

        class CustomModuleInterface(torch.nn.Module):
            pass

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.moduledict = CustomModuleDict()

            def forward(self, inputs):
                assert (
                    "submod" not in self.moduledict
                ), "__contains__ fails for ModuleDict"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])

    def test_typed_module_dict(self):
        """
        Test that a type annotation can be provided for a ModuleDict that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        # Its forward signature (Tensor -> Tuple) differs from the interface's
        # (Any -> Any).
        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                value: ModuleInterface = self.d[key]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test annotation of self.
        class ModDict(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self[key]
                return submodule.forward(x)

        m = ModDict()
        self.checkModule(m, (torch.randn(2, 2), "module"))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleDict):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                submodule: ModuleInterface = self.d[key]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_typed_module_list(self):
        """
        Test that a type annotation can be provided for a ModuleList that allows
        non-static indexing.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        # Its forward signature (Tensor -> Tuple) differs from the interface's
        # (Any -> Any).
        class DoesNotImplementInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                return torch.max(inp, dim=0)

        # Test annotation of submodule.
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                value: ModuleInterface = self.l[idx]
                return value.forward(x)

        m = Mod()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test annotation of self.
        class ModList(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self[idx]
                return submodule.forward(x)

        m = ModList()
        self.checkModule(m, (torch.randn(2, 2), 0))

        # Test error message thrown when annotated attribute does not comply with the
        # annotation.
        class ModWithWrongAnnotation(torch.nn.ModuleList):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([DoesNotImplementInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                submodule: ModuleInterface = self.l[idx]
                return submodule.forward(x)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
        ):
            torch.jit.script(ModWithWrongAnnotation())

    def test_module_properties(self):
        """Properties are compiled (getter and setter) unless marked unused."""

        class ModuleWithProperties(torch.nn.Module):
            # Properties listed here are skipped by the compiler.
            __jit_unused_properties__ = ["ignored_attr"]

            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                # Goes through the `attr` setter, then the getter.
                self.attr = a + b
                return self.attr

            @property
            def attr(self):
                return self.a

            @property
            def ignored_attr(self):
                return sum([self.a])

            @torch.jit.unused
            @property
            def ignored_attr_2(self):
                return sum([self.a])

            @ignored_attr_2.setter
            def ignored_attr_2(self, value):
                self.a = sum([self.a])

            @attr.setter
            def attr(self, a: int):
                if a > 0:
                    self.a = a
                else:
                    self.a = 0

        class ModuleWithNoSetter(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.a = a

            def forward(self, a: int, b: int):
                # Return the value so that reading the getter-only property is
                # actually compared between eager and scripted execution.
                return self.attr + a + b

            @property
            def attr(self):
                return self.a + 1

        # Positive and negative sums exercise both branches of the `attr` setter.
        self.checkModule(
            ModuleWithProperties(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithProperties(5),
            (
                -5,
                -6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                5,
                6,
            ),
        )
        self.checkModule(
            ModuleWithNoSetter(5),
            (
                -5,
                -6,
            ),
        )

        # A property listed in __jit_unused_properties__ is absent from the
        # scripted module.
        mod = ModuleWithProperties(3)
        scripted_mod = torch.jit.script(mod)

        with self.assertRaisesRegex(AttributeError, "has no attribute"):
            scripted_mod.ignored_attr

    def test_module_inplace_construct(self):
        """``_reconstruct`` swaps in another module's state but keeps ignored methods."""

        class M(nn.Module):
            def __init__(self, start: int):
                super().__init__()
                self.linear = nn.Linear(3, 3)
                self.attribute = start
                self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))

            def method(self) -> int:
                return self.attribute

            @torch.jit.unused
            def unused_method(self):
                return self.attribute + self.attribute

            def forward(self, x):
                return self.linear(self.linear(x))

        class N(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(4, 4)

            @torch.jit.ignore
            def ignored_method(self, x):
                return x

            def forward(self, x):
                return self.linear(x)

        m = torch.jit.script(M(3))
        n = torch.jit.script(N())

        n._reconstruct(m._c)

        # Size 3 matches M's Linear(3, 3), not N's original Linear(4, 4).
        inp = torch.rand(3)

        # Check that both modules produce the same output.
        with torch.no_grad():
            m_out = m(inp)
            n_out = n(inp)
            self.assertEqual(m_out, n_out)

        # Check that ignored method is still intact.
        self.assertEqual(inp, n.ignored_method(inp))

    def test_parameterlist_script_getitem(self):
        """Literal indexing into ModuleList and ParameterList compiles."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                self.module_list[0]
                self.parameter_list[0]
                return x

        # Trailing comma: checkModule expects a tuple of inputs, as in the
        # sibling tests; `(torch.zeros(1))` would be a bare tensor.
        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterlist_script_iter(self):
        """enumerate() over a ParameterList compiles."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
                self.parameter_list = nn.ParameterList(
                    [nn.Parameter(torch.zeros(1)) for _ in range(10)]
                )

            def forward(self, x):
                r = x
                for i, p in enumerate(self.parameter_list):
                    r = r + p + i
                return r

        self.checkModule(MyModule(), (torch.zeros(1),))

    def test_parameterdict_script_getitem(self):
        """Literal-key indexing into a ParameterDict compiles."""

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.parameter_dict = nn.ParameterDict(
                    {k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
                )

            def forward(self, x):
                return (
                    self.parameter_dict["a"] * x
                    + self.parameter_dict["b"] * self.parameter_dict["c"]
                )

        self.checkModule(MyModule(), (torch.ones(1),))
759