# xref: /aosp_15_r20/external/pytorch/test/jit/test_module_apis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import Any, Dict, List
6
7import torch
8from torch.testing._internal.jit_utils import JitTestCase
9
10
11# Make the helper files in test/ importable
12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13sys.path.append(pytorch_test_dir)
14
15if __name__ == "__main__":
16    raise RuntimeError(
17        "This test file is not meant to be run directly, use:\n\n"
18        "\tpython test/test_jit.py TESTNAME\n\n"
19        "instead."
20    )
21
22
23class TestModuleAPIs(JitTestCase):
24    def test_default_state_dict_methods(self):
25        """Tests that default state dict methods are automatically available"""
26
27        class DefaultStateDictModule(torch.nn.Module):
28            def __init__(self) -> None:
29                super().__init__()
30                self.conv = torch.nn.Conv2d(6, 16, 5)
31                self.fc = torch.nn.Linear(16 * 5 * 5, 120)
32
33            def forward(self, x):
34                x = self.conv(x)
35                x = self.fc(x)
36                return x
37
38        m1 = torch.jit.script(DefaultStateDictModule())
39        m2 = torch.jit.script(DefaultStateDictModule())
40        state_dict = m1.state_dict()
41        m2.load_state_dict(state_dict)
42
43    def test_customized_state_dict_methods(self):
44        """Tests that customized state dict methods are in effect"""
45
46        class CustomStateDictModule(torch.nn.Module):
47            def __init__(self) -> None:
48                super().__init__()
49                self.conv = torch.nn.Conv2d(6, 16, 5)
50                self.fc = torch.nn.Linear(16 * 5 * 5, 120)
51                self.customized_save_state_dict_called: bool = False
52                self.customized_load_state_dict_called: bool = False
53
54            def forward(self, x):
55                x = self.conv(x)
56                x = self.fc(x)
57                return x
58
59            @torch.jit.export
60            def _save_to_state_dict(
61                self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
62            ):
63                self.customized_save_state_dict_called = True
64                return {"dummy": torch.ones(1)}
65
66            @torch.jit.export
67            def _load_from_state_dict(
68                self,
69                state_dict: Dict[str, torch.Tensor],
70                prefix: str,
71                local_metadata: Any,
72                strict: bool,
73                missing_keys: List[str],
74                unexpected_keys: List[str],
75                error_msgs: List[str],
76            ):
77                self.customized_load_state_dict_called = True
78                return
79
80        m1 = torch.jit.script(CustomStateDictModule())
81        self.assertFalse(m1.customized_save_state_dict_called)
82        state_dict = m1.state_dict()
83        self.assertTrue(m1.customized_save_state_dict_called)
84
85        m2 = torch.jit.script(CustomStateDictModule())
86        self.assertFalse(m2.customized_load_state_dict_called)
87        m2.load_state_dict(state_dict)
88        self.assertTrue(m2.customized_load_state_dict_called)
89
90    def test_submodule_customized_state_dict_methods(self):
91        """Tests that customized state dict methods on submodules are in effect"""
92
93        class CustomStateDictModule(torch.nn.Module):
94            def __init__(self) -> None:
95                super().__init__()
96                self.conv = torch.nn.Conv2d(6, 16, 5)
97                self.fc = torch.nn.Linear(16 * 5 * 5, 120)
98                self.customized_save_state_dict_called: bool = False
99                self.customized_load_state_dict_called: bool = False
100
101            def forward(self, x):
102                x = self.conv(x)
103                x = self.fc(x)
104                return x
105
106            @torch.jit.export
107            def _save_to_state_dict(
108                self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
109            ):
110                self.customized_save_state_dict_called = True
111                return {"dummy": torch.ones(1)}
112
113            @torch.jit.export
114            def _load_from_state_dict(
115                self,
116                state_dict: Dict[str, torch.Tensor],
117                prefix: str,
118                local_metadata: Any,
119                strict: bool,
120                missing_keys: List[str],
121                unexpected_keys: List[str],
122                error_msgs: List[str],
123            ):
124                self.customized_load_state_dict_called = True
125                return
126
127        class ParentModule(torch.nn.Module):
128            def __init__(self) -> None:
129                super().__init__()
130                self.sub = CustomStateDictModule()
131
132            def forward(self, x):
133                return self.sub(x)
134
135        m1 = torch.jit.script(ParentModule())
136        self.assertFalse(m1.sub.customized_save_state_dict_called)
137        state_dict = m1.state_dict()
138        self.assertTrue(m1.sub.customized_save_state_dict_called)
139
140        m2 = torch.jit.script(ParentModule())
141        self.assertFalse(m2.sub.customized_load_state_dict_called)
142        m2.load_state_dict(state_dict)
143        self.assertTrue(m2.sub.customized_load_state_dict_called)
144