# Owner(s): ["oncall: jit"]

from typing import List, Tuple

import torch


class SubmoduleNoForwardInputs(torch.nn.Module):
    """Inner module whose ``forward`` takes no arguments and returns nothing."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self):
        # The only work done here is checking the name given at construction.
        assert self.name == "inner_mod_name"


class ModuleNoForwardInputs(torch.nn.Module):
    """Outer module wrapping a ``SubmoduleNoForwardInputs``; no forward inputs."""

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleNoForwardInputs(submodule_name)

    def forward(self):
        # Invoke through __call__ so hooks registered on the submodule run.
        self.submodule()


class SubmoduleForwardSingleInput(torch.nn.Module):
    """Inner module whose ``forward`` takes one string and tags it."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def foo(self, input: str):
        # Identity helper; forward routes its result through this method.
        return input

    def forward(self, input: str):
        # Append the inner-module suffix, then pass the result through foo.
        return self.foo(input + "_inner_mod")


class ModuleForwardSingleInput(torch.nn.Module):
    """Outer module that tags a string and hands it to its submodule."""

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardSingleInput(submodule_name)

    def forward(self, input: str):
        tagged = input + "_outermod"
        return self.submodule(tagged)


class ModuleDirectforwardSubmodCall(torch.nn.Module):
    """Outer module that calls ``submodule.forward`` directly, not ``__call__``."""

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardSingleInput(submodule_name)

    def forward(self, input: str):
        tagged = input + "_outermod"
        # Deliberately uses .forward(...) rather than calling the submodule.
        return self.submodule.forward(tagged)


class SuboduleForwardMultipleInputs(torch.nn.Module):
    """Inner module whose ``forward`` takes a list of strings and a string."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, input1: List[str], input2: str):
        # The list is mutated in place; the string gets a trailing underscore.
        input1.append(self.name)
        return input1, input2 + "_"


class ModuleForwardMultipleInputs(torch.nn.Module):
    """Outer module forwarding a list and a string to its submodule."""

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SuboduleForwardMultipleInputs(submodule_name)

    def forward(self, input1: List[str], input2: str):
        # Record this module's name before delegating to the submodule.
        input1.append(self.name)
        return self.submodule(input1, input2)


class SubmoduleForwardTupleInput(torch.nn.Module):
    """Inner module whose single ``forward`` argument is itself a tuple."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, input: Tuple[int]):
        # Index into the tuple input; the value itself is not used further.
        first_element = input[0]
        return (1,)


class ModuleForwardTupleInput(torch.nn.Module):
    """Outer module whose single ``forward`` argument is itself a tuple."""

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardTupleInput(submodule_name)

    def forward(self, input: Tuple[int]):
        # Index into the tuple input; the value itself is not used further.
        first_element = input[0]
        return self.submodule((1,))


# Modules for JIT forward hook and pre-hooks python and cpp tests
def create_module_no_forward_input():
    """Build a module for testing module level hooks with no forward input."""

    def pre_hook(self, input: Tuple[()]) -> None:
        assert self.name == "outer_mod_name"

    def forward_hook(self, input: Tuple[()], output: None):
        assert self.name == "outer_mod_name"

    module = ModuleNoForwardInputs("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(forward_hook)
    return module


def create_submodule_no_forward_input():
    """Build a module for testing submodule level hooks with no forward input."""

    def pre_hook(self, input: Tuple[()]) -> None:
        assert self.name == "inner_mod_name"

    def forward_hook(self, input: Tuple[()], output: None):
        assert self.name == "inner_mod_name"

    module = ModuleNoForwardInputs("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_module_forward_multiple_inputs():
    """Build a module for testing module level hooks when forward has
    multiple inputs and multiple returns.
    """

    def pre_hook(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "outer_mod_name"
        assert input[0][0] == "a"
        # Replace both forward arguments.
        return (["pre_hook_override_name"], "pre_hook_override")

    def forward_hook(self, input: Tuple[List[str], str], output: Tuple[List[str], str]):
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name"
        # Pass the list through untouched and tag the string output.
        return output[0], output[1] + "fh"

    module = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(forward_hook)
    return module


def create_module_multiple_hooks_multiple_inputs():
    """Build a module for testing that module level hooks with multiple inputs
    execute in the correct order and pass correct information to each other.
    """

    def pre_hook1(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "outer_mod_name"
        assert input[0][0] == "a"
        return (["pre_hook_override_name"], "pre_hook_override")

    def pre_hook2(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "outer_mod_name"
        # Sees the values produced by pre_hook1.
        assert input[0][0] == "pre_hook_override_name"
        return (["pre_hook_override_name2"], "pre_hook_override")

    def forward_hook1(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name2"
        return output[0], output[1] + "fh1"

    def forward_hook2(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name2"
        # Sees the output produced by forward_hook1.
        assert output[1] == "pre_hook_override_fh1"
        return output[0], output[1] + "_fh2"

    module = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook1)
    module.register_forward_pre_hook(pre_hook2)
    module.register_forward_hook(forward_hook1)
    module.register_forward_hook(forward_hook2)
    return module


def create_module_forward_single_input():
    """Build a module for testing module level hooks for forward with a
    single input.
    """

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(forward_hook)
    return module


def create_module_same_hook_repeated():
    """Build a module for testing that a module can run the same hook
    multiple times.
    """

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        return (input[0] + "_ph",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        # The pre-hook is registered twice, so its suffix appears twice.
        assert input == ("a_ph_ph",)
        return output + "_fh"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    # Each hook is intentionally registered twice.
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(forward_hook)
    module.register_forward_hook(forward_hook)
    return module


def create_module_hook_return_nothing():
    """Build a module for testing module level hooks that return nothing."""

    def pre_hook(self, input: Tuple[str]) -> None:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        assert input == ("a",)

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(forward_hook)
    return module


def create_module_multiple_hooks_single_input():
    """Build a module for testing that modules can run multiple hooks with a
    single input.
    """

    def pre_hook1(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"
        return ("pre_hook_override_name1",)

    def pre_hook2(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "pre_hook_override_name1"
        return ("pre_hook_override_name2",)

    def forward_hook1(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_outermod_inner_mod"
        tagged = output + "_fh1"
        # Returns a pair, changing the output type seen by the next hook.
        return tagged, tagged

    def forward_hook2(self, input: Tuple[str], output: Tuple[str, str]):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output[0] == "pre_hook_override_name2_outermod_inner_mod_fh1"
        # Collapses the pair from forward_hook1 back to a single string.
        return output[0] + "_fh2"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook1)
    module.register_forward_pre_hook(pre_hook2)
    module.register_forward_hook(forward_hook1)
    module.register_forward_hook(forward_hook2)
    return module


def create_submodule_forward_multiple_inputs():
    """Build a module for testing that submodules can run hooks that have
    multiple forward inputs.
    """

    def pre_hook(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "inner_mod_name"
        # The outer forward appended its name before calling the submodule.
        assert input[0][1] == "outer_mod_name"
        return (["pre_hook_override_name"], "pre_hook_override")

    def forward_hook(self, input: Tuple[List[str], str], output: Tuple[List[str], str]):
        assert self.name == "inner_mod_name"
        assert input[0][0] == "pre_hook_override_name"
        return output[0], output[1] + "fh"

    module = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_submodule_multiple_hooks_multiple_inputs():
    """Build a module for testing that submodules can run multiple hooks with
    multiple forward inputs.
    """

    def pre_hook1(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "inner_mod_name"
        assert input[1] == "no_pre_hook"
        return (["pre_hook_override_name"], "pre_hook_override1")

    def pre_hook2(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override1"
        return (["pre_hook_override_name"], "pre_hook_override2")

    def forward_hook1(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override2"
        assert output[1] == "pre_hook_override2_"
        tagged = output[1] + "fh1"
        # Returns a 3-tuple, changing the output type seen by the next hook.
        return output[0], tagged, tagged

    def forward_hook2(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str, str]
    ):
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override2"
        assert output[1] == "pre_hook_override2_fh1"
        # Collapses the 3-tuple from forward_hook1 back to a pair.
        return output[0], output[1] + "_fh2"

    module = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook1)
    module.submodule.register_forward_pre_hook(pre_hook2)
    module.submodule.register_forward_hook(forward_hook1)
    module.submodule.register_forward_hook(forward_hook2)
    return module


def create_submodule_forward_single_input():
    """Build a module for testing that submodules can run hooks with a single
    argument passed to forward.
    """

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        # Output is handed back unchanged.
        return output

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_submodule_to_call_directly_with_hooks():
    """Build a module for testing that submodules have their hooks invoked
    when called directly.
    """

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        # No assertion on the input value; it is overridden regardless.
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_submodule_same_hook_repeated():
    """Build a module for testing that submodules can run the same hooks
    multiple times.
    """

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        return (input[0] + "_ph",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        # The pre-hook is registered twice, so its suffix appears twice.
        assert input == ("a_outermod_ph_ph",)
        return output + "_fh"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    # Each hook is intentionally registered twice.
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_submodule_hook_return_nothing():
    """Build a module for testing that submodules can run hooks that return
    nothing.
    """

    def pre_hook(self, input: Tuple[str]) -> None:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("a_outermod",)

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module


def create_submodule_multiple_hooks_single_input():
    """Build a module for testing that submodules can run multiple hooks that
    have a single input.
    """

    def pre_hook1(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        return ("pre_hook_override_name",)

    def pre_hook2(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        assert input[0] == "pre_hook_override_name"
        return ("pre_hook_override_name2",)

    def forward_hook1(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_inner_mod"
        return output + "_fwh1"

    def forward_hook2(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_inner_mod_fwh1"
        # Output is handed back unchanged.
        return output

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook1)
    module.submodule.register_forward_pre_hook(pre_hook2)
    module.submodule.register_forward_hook(forward_hook1)
    module.submodule.register_forward_hook(forward_hook2)
    return module


def create_forward_tuple_input():
    """Build a module for testing the case where forward is passed a single
    tuple as its input.

    This case is different because eager always wraps pre-hook return
    arguments in a tuple when the returned pre-hook result isn't a tuple
    (to allow the result to be passed to another pre-hook if needed).
    The eager behavior doesn't wrap the single tuple input pre-hook return
    in a tuple as it should. To get consistent behavior between single tuple
    inputs and the rest of the possible forward inputs, pre-hooks need to
    wrap single tuple input returns in another tuple. This is enforced by
    the schema checker.
    """

    def pre_hook_outermod(self, input: Tuple[Tuple[int]]) -> Tuple[Tuple[int]]:
        # 'return (11,)' doesn't work with eager, inner tuple lost
        return ((11,),)

    def pre_hook_innermod(self, input: Tuple[Tuple[int]]) -> Tuple[Tuple[int]]:
        # 'return (22,)' doesn't work with eager, inner tuple lost
        return ((22,),)

    def forward_hook_outermod(self, input: Tuple[Tuple[int]], output: int):
        return (11,)

    def forward_hook_innermod(self, input: Tuple[Tuple[int]], output: Tuple[int]):
        return 22

    module = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
    module.register_forward_pre_hook(pre_hook_outermod)
    module.submodule.register_forward_pre_hook(pre_hook_innermod)
    module.register_forward_hook(forward_hook_outermod)
    module.submodule.register_forward_hook(forward_hook_innermod)
    return module


def create_submodule_forward_single_input_return_not_tupled():
    """Build a module for checking that submodules can return modified inputs
    that aren't wrapped in a tuple (to match eager behavior).
    """

    def pre_hook(self, input: Tuple[str]) -> str:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        # return is wrapped in tuple in other test cases
        return "pre_hook_override_name"

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    module = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
    module.submodule.register_forward_pre_hook(pre_hook)
    module.submodule.register_forward_hook(forward_hook)
    return module