1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import Any, Dict, List 6 7import torch 8from torch.testing._internal.jit_utils import JitTestCase 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14 15if __name__ == "__main__": 16 raise RuntimeError( 17 "This test file is not meant to be run directly, use:\n\n" 18 "\tpython test/test_jit.py TESTNAME\n\n" 19 "instead." 20 ) 21 22 23class TestModuleAPIs(JitTestCase): 24 def test_default_state_dict_methods(self): 25 """Tests that default state dict methods are automatically available""" 26 27 class DefaultStateDictModule(torch.nn.Module): 28 def __init__(self) -> None: 29 super().__init__() 30 self.conv = torch.nn.Conv2d(6, 16, 5) 31 self.fc = torch.nn.Linear(16 * 5 * 5, 120) 32 33 def forward(self, x): 34 x = self.conv(x) 35 x = self.fc(x) 36 return x 37 38 m1 = torch.jit.script(DefaultStateDictModule()) 39 m2 = torch.jit.script(DefaultStateDictModule()) 40 state_dict = m1.state_dict() 41 m2.load_state_dict(state_dict) 42 43 def test_customized_state_dict_methods(self): 44 """Tests that customized state dict methods are in effect""" 45 46 class CustomStateDictModule(torch.nn.Module): 47 def __init__(self) -> None: 48 super().__init__() 49 self.conv = torch.nn.Conv2d(6, 16, 5) 50 self.fc = torch.nn.Linear(16 * 5 * 5, 120) 51 self.customized_save_state_dict_called: bool = False 52 self.customized_load_state_dict_called: bool = False 53 54 def forward(self, x): 55 x = self.conv(x) 56 x = self.fc(x) 57 return x 58 59 @torch.jit.export 60 def _save_to_state_dict( 61 self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool 62 ): 63 self.customized_save_state_dict_called = True 64 return {"dummy": torch.ones(1)} 65 66 @torch.jit.export 67 def _load_from_state_dict( 68 self, 69 state_dict: Dict[str, torch.Tensor], 70 prefix: str, 71 local_metadata: Any, 72 strict: bool, 73 missing_keys: List[str], 74 unexpected_keys: List[str], 75 error_msgs: List[str], 76 ): 77 self.customized_load_state_dict_called = True 78 return 79 80 m1 = torch.jit.script(CustomStateDictModule()) 81 self.assertFalse(m1.customized_save_state_dict_called) 82 state_dict = m1.state_dict() 83 self.assertTrue(m1.customized_save_state_dict_called) 84 85 m2 = torch.jit.script(CustomStateDictModule()) 86 self.assertFalse(m2.customized_load_state_dict_called) 87 m2.load_state_dict(state_dict) 88 self.assertTrue(m2.customized_load_state_dict_called) 89 90 def test_submodule_customized_state_dict_methods(self): 91 """Tests that customized state dict methods on submodules are in effect""" 92 93 class CustomStateDictModule(torch.nn.Module): 94 def __init__(self) -> None: 95 super().__init__() 96 self.conv = torch.nn.Conv2d(6, 16, 5) 97 self.fc = torch.nn.Linear(16 * 5 * 5, 120) 98 self.customized_save_state_dict_called: bool = False 99 self.customized_load_state_dict_called: bool = False 100 101 def forward(self, x): 102 x = self.conv(x) 103 x = self.fc(x) 104 return x 105 106 @torch.jit.export 107 def _save_to_state_dict( 108 self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool 109 ): 110 self.customized_save_state_dict_called = True 111 return {"dummy": torch.ones(1)} 112 113 @torch.jit.export 114 def _load_from_state_dict( 115 self, 116 state_dict: Dict[str, torch.Tensor], 117 prefix: str, 118 local_metadata: Any, 119 strict: bool, 120 missing_keys: List[str], 121 unexpected_keys: List[str], 122 error_msgs: List[str], 123 ): 124 self.customized_load_state_dict_called = True 125 return 126 127 class ParentModule(torch.nn.Module): 128 def __init__(self) -> None: 129 super().__init__() 130 self.sub = CustomStateDictModule() 131 132 def forward(self, x): 133 return self.sub(x) 134 135 m1 = torch.jit.script(ParentModule()) 136 self.assertFalse(m1.sub.customized_save_state_dict_called) 137 state_dict = m1.state_dict() 138 self.assertTrue(m1.sub.customized_save_state_dict_called) 139 140 m2 = torch.jit.script(ParentModule()) 141 self.assertFalse(m2.sub.customized_load_state_dict_called) 142 m2.load_state_dict(state_dict) 143 self.assertTrue(m2.sub.customized_load_state_dict_called) 144