# Owner(s): ["oncall: jit"]

"""Tests for TorchScript support of forward hooks and forward pre-hooks.

The hooked modules under test come from the ``jit.test_hooks_modules`` helper
module. This file is meant to be collected via ``test/test_jit.py`` rather than
executed directly (see the ``__main__`` guard below).
"""

import os
import sys
import unittest
from typing import Tuple

import torch


# Make the helper files in test/ importable. This has to happen *before* the
# ``jit.*`` helper import below; otherwise running this file without test/ on
# ``sys.path`` fails with an ImportError instead of reaching the guard's
# explanatory RuntimeError.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

from jit.test_hooks_modules import (  # noqa: E402
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_forward_single_input_return_not_tupled,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_no_forward_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
    ModuleDirectforwardSubmodCall,
    ModuleForwardSingleInput,
    ModuleForwardTupleInput,
)
from torch.testing._internal.jit_utils import JitTestCase  # noqa: E402


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """Scripting tests for modules that have forward hooks / pre-hooks attached.

    Most tests build a hooked module with one of the ``create_*`` factories and
    hand it to ``JitTestCase.checkModule`` together with the forward inputs.
    The inputs are presumably unpacked as positional arguments, so they must be
    spelled as a tuple (note the trailing comma in ``("a",)``).
    """

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        """Calling a scripted submodule on its own matches the eager submodule."""
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # Previously written as `(["a"])`, which is a list rather than a
        # 1-tuple (no trailing comma); use the same tuple form as the other
        # single-input tests.
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    def test_hook_method_name_collision(self):
        """A hook whose name clashes with an existing method is rejected."""
        # Hooks can't have the same name as methods.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        """Two distinct Python hooks sharing one name are rejected when scripting."""
        # Test edge case of two hooks sharing name but not python definition
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        # The redefinition under the same name is the point of this test.
        def prehook(self, input: Tuple[str]) -> Tuple[str]:  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        # Same check for forward hooks (fresh module so the pre-hooks above
        # don't interfere).
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        def hook(self, input: Tuple[str]):  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    def test_module_direct_forward_invocation(self):
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        # `.forward(...)` bypasses hooks in both eager and scripted modes ...
        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # ... while `__call__` runs them, so the results must differ.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        """Submodule hooks fire on submodule calls but not on ``.forward`` calls.

        ``ModuleDirectforwardSubmodCall`` presumably invokes
        ``submodule.forward`` directly, while ``ModuleForwardSingleInput`` calls
        the submodule normally; the same hooks are attached to both.
        """
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip("TODO: add this test back once figured out how to print error msg")
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "This error occurred while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        """Each malformed pre-hook signature produces its specific scripting error."""
        # correct signature: pre_hook_c(self, input: Tuple[str])
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        """Each malformed forward-hook signature produces its specific scripting error."""
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # A chained hook must accept the output type produced by the previous
        # hook: forward_hook_correct returns Tuple[str], but the next hook
        # declares `output: str`.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)