xref: /aosp_15_r20/external/pytorch/test/jit/test_hooks_modules.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3from typing import List, Tuple
4
5import torch
6
7
class SubmoduleNoForwardInputs(torch.nn.Module):
    """Inner module whose ``forward`` takes no arguments and returns nothing.

    ``forward`` only asserts that the module was constructed with the name
    ``"inner_mod_name"``.
    """

    def __init__(self, name):
        super().__init__()
        # Checked by forward() and by hooks registered on this module.
        self.name = name

    def forward(self):
        expected_name = "inner_mod_name"
        assert self.name == expected_name
15
16
class ModuleNoForwardInputs(torch.nn.Module):
    """Outer module whose ``forward`` takes no arguments.

    It owns a ``SubmoduleNoForwardInputs`` and simply invokes it, so hooks on
    either module can be exercised without any forward inputs.
    """

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleNoForwardInputs(submodule_name)

    def forward(self):
        # Invoked through __call__ (not .forward) so the submodule's hooks run.
        self.submodule()
24        self.submodule()
25
26
class SubmoduleForwardSingleInput(torch.nn.Module):
    """Inner module whose ``forward`` takes and returns a single string.

    ``forward`` appends ``"_inner_mod"`` to its input and passes the result
    through the identity helper ``foo``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name

    def foo(self, input: str):
        # Identity helper; forward() routes its result through this method.
        return input

    def forward(self, input: str):
        return self.foo(input + "_inner_mod")
38        return input
39
40
class ModuleForwardSingleInput(torch.nn.Module):
    """Outer module whose ``forward`` takes a single string.

    ``forward`` appends ``"_outermod"`` and hands the result to a
    ``SubmoduleForwardSingleInput`` through its ``__call__`` (hooks included).
    """

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardSingleInput(submodule_name)

    def forward(self, input: str):
        return self.submodule(input + "_outermod")
49        return self.submodule(input)
50
51
class ModuleDirectforwardSubmodCall(torch.nn.Module):
    """Outer module that calls its submodule's ``forward`` method directly.

    Apart from the call style, this matches ``ModuleForwardSingleInput``: it
    appends ``"_outermod"`` to the input before delegating.
    """

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardSingleInput(submodule_name)

    def forward(self, input: str):
        # Calling .forward directly (rather than self.submodule(...)) skips
        # nn.Module.__call__, so in eager mode the submodule's hooks do not run.
        return self.submodule.forward(input + "_outermod")
60        return self.submodule.forward(input)
61
62
class SuboduleForwardMultipleInputs(torch.nn.Module):
    """Inner module whose ``forward`` takes and returns two values.

    NOTE(review): "Subodule" looks like a typo for "Submodule"; the name is
    kept as-is because other code may reference it.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, input1: List[str], input2: str):
        # The caller's list is mutated in place and the same object returned.
        input1.append(self.name)
        return input1, input2 + "_"
71        return input1, output2
72
73
class ModuleForwardMultipleInputs(torch.nn.Module):
    """Outer module whose ``forward`` takes a list of names and a string.

    It appends its own name to the list in place, then delegates both
    arguments to a ``SuboduleForwardMultipleInputs`` via ``__call__``.
    """

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SuboduleForwardMultipleInputs(submodule_name)

    def forward(self, input1: List[str], input2: str):
        # In-place append: the submodule (and its hooks) see this name at
        # index 1 when the caller passed a single-element list.
        input1.append(self.name)
        return self.submodule(input1, input2)
82        return self.submodule(input1, input2)
83
84
class SubmoduleForwardTupleInput(torch.nn.Module):
    """Inner module whose ``forward`` takes a single 1-tuple of int.

    ``forward`` reads the tuple's first element but ignores it, and always
    returns the constant ``(1,)``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name

    def forward(self, input: Tuple[int]):
        # Presumably indexed only to exercise tuple element access; the
        # value is intentionally unused.
        first_item = input[0]
        return (1,)
92        return (1,)
93
94
class ModuleForwardTupleInput(torch.nn.Module):
    """Outer module whose ``forward`` takes a single 1-tuple of int.

    It reads the input's first element, then calls its
    ``SubmoduleForwardTupleInput`` with the constant ``(1,)`` regardless of
    the input.
    """

    def __init__(self, name: str, submodule_name: str):
        super().__init__()
        self.name = name
        self.submodule = SubmoduleForwardTupleInput(submodule_name)

    def forward(self, input: Tuple[int]):
        # Element read is intentionally unused (mirrors the submodule).
        first_item = input[0]
        return self.submodule((1,))
103        return self.submodule((1,))
104
105
# Factory functions that build modules with forward hooks and forward pre-hooks, shared by the JIT Python and C++ tests
def create_module_no_forward_input():
    """Return a ModuleNoForwardInputs with hooks on the outer module.

    Covers module-level hooks when ``forward`` takes no inputs. Both hooks
    receive an empty input tuple and only assert that they are bound to the
    outer module.
    """
    outer = ModuleNoForwardInputs("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[()]) -> None:
        assert self.name == "outer_mod_name"

    def forward_hook(self, input: Tuple[()], output: None):
        assert self.name == "outer_mod_name"

    outer.register_forward_pre_hook(pre_hook)
    outer.register_forward_hook(forward_hook)
    return outer
120    return m
121
122
def create_submodule_no_forward_input():
    """Return a ModuleNoForwardInputs with hooks on its submodule.

    Covers submodule-level hooks when ``forward`` takes no inputs. Both hooks
    receive an empty input tuple and only assert that they are bound to the
    inner module.
    """
    root = ModuleNoForwardInputs("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[()]) -> None:
        assert self.name == "inner_mod_name"

    def forward_hook(self, input: Tuple[()], output: None):
        assert self.name == "inner_mod_name"

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
136    return m
137
138
def create_module_forward_multiple_inputs():
    """Return a ModuleForwardMultipleInputs with one pre-hook and one forward
    hook on the outer module.

    Covers module-level hooks for a ``forward`` that takes and returns two
    values. The pre-hook expects the list argument to start with ``"a"`` and
    replaces both inputs. The forward hook checks that replacement and
    appends ``"fh"`` to the string output.
    """
    outer = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "outer_mod_name"
        assert input[0][0] == "a"
        return (["pre_hook_override_name"], "pre_hook_override")

    def forward_hook(self, input: Tuple[List[str], str], output: Tuple[List[str], str]):
        # Inputs seen here are the ones produced by pre_hook.
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name"
        return output[0], output[1] + "fh"

    outer.register_forward_pre_hook(pre_hook)
    outer.register_forward_hook(forward_hook)
    return outer
158    return m
159
160
def create_module_multiple_hooks_multiple_inputs():
    """Return a ModuleForwardMultipleInputs with two pre-hooks and two forward
    hooks on the outer module.

    Checks that hooks with multiple inputs run in registration order and that
    each hook sees the result of the one before it.
    """
    outer = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")

    def pre_hook1(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "outer_mod_name"
        assert input[0][0] == "a"
        return (["pre_hook_override_name"], "pre_hook_override")

    def pre_hook2(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        # Must observe pre_hook1's replacement inputs.
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name"
        return (["pre_hook_override_name2"], "pre_hook_override")

    def forward_hook1(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        # Inputs are those left by the last pre-hook (pre_hook2).
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name2"
        return output[0], output[1] + "fh1"

    def forward_hook2(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        # Must observe forward_hook1's modified output.
        assert self.name == "outer_mod_name"
        assert input[0][0] == "pre_hook_override_name2"
        assert output[1] == "pre_hook_override_fh1"
        return output[0], output[1] + "_fh2"

    for pre in (pre_hook1, pre_hook2):
        outer.register_forward_pre_hook(pre)
    for post in (forward_hook1, forward_hook2):
        outer.register_forward_hook(post)
    return outer
198    return m
199
200
def create_module_forward_single_input():
    """Return a ModuleForwardSingleInput with hooks on the outer module.

    Covers module-level hooks for a single-argument ``forward``. The pre-hook
    expects ``"a"`` and replaces it. The forward hook checks the replacement
    and appends ``"_fh"`` to the output.
    """
    outer = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    outer.register_forward_pre_hook(pre_hook)
    outer.register_forward_hook(forward_hook)
    return outer
219    return m
220
221
def create_module_same_hook_repeated():
    """Return a ModuleForwardSingleInput whose outer module has the same
    pre-hook and forward hook each registered twice.

    Checks that a module can run an identical hook more than once.
    """
    outer = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        return (input[0] + "_ph",)

    def forward_hook(self, input: Tuple[str], output: str):
        # Both pre_hook registrations ran, so "_ph" was appended twice.
        assert self.name == "outer_mod_name"
        assert input == ("a_ph_ph",)
        return output + "_fh"

    # Register the very same function objects twice each.
    for _ in range(2):
        outer.register_forward_pre_hook(pre_hook)
    for _ in range(2):
        outer.register_forward_hook(forward_hook)
    return outer
242    return m
243
244
def create_module_hook_return_nothing():
    """Return a ModuleForwardSingleInput whose outer-module hooks return None.

    Because neither hook returns a value, the forward input and output pass
    through unchanged. The hooks only assert what they observe.
    """
    outer = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> None:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"

    def forward_hook(self, input: Tuple[str], output: str):
        # The pre-hook returned nothing, so the original input is still seen.
        assert self.name == "outer_mod_name"
        assert input == ("a",)

    outer.register_forward_pre_hook(pre_hook)
    outer.register_forward_hook(forward_hook)
    return outer
260    return m
261
262
def create_module_multiple_hooks_single_input():
    """Return a ModuleForwardSingleInput with two pre-hooks and two forward
    hooks on the outer module.

    Checks chaining for a single-argument ``forward``. ``forward_hook1``
    changes the output type from ``str`` to a 2-tuple, which
    ``forward_hook2`` then consumes and turns back into a ``str``.
    """
    outer = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook1(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "a"
        return ("pre_hook_override_name1",)

    def pre_hook2(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "outer_mod_name"
        assert input[0] == "pre_hook_override_name1"
        return ("pre_hook_override_name2",)

    def forward_hook1(self, input: Tuple[str], output: str):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_outermod_inner_mod"
        extended = output + "_fh1"
        return extended, extended

    def forward_hook2(self, input: Tuple[str], output: Tuple[str, str]):
        assert self.name == "outer_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output[0] == "pre_hook_override_name2_outermod_inner_mod_fh1"
        return output[0] + "_fh2"

    for pre in (pre_hook1, pre_hook2):
        outer.register_forward_pre_hook(pre)
    for post in (forward_hook1, forward_hook2):
        outer.register_forward_hook(post)
    return outer
296    return m
297
298
def create_submodule_forward_multiple_inputs():
    """Return a ModuleForwardMultipleInputs with hooks on its submodule.

    Checks that submodule hooks work when ``forward`` takes multiple inputs.
    """
    root = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        # The outer forward appended its own name before calling us.
        assert self.name == "inner_mod_name"
        assert input[0][1] == "outer_mod_name"
        return (["pre_hook_override_name"], "pre_hook_override")

    def forward_hook(self, input: Tuple[List[str], str], output: Tuple[List[str], str]):
        assert self.name == "inner_mod_name"
        assert input[0][0] == "pre_hook_override_name"
        return output[0], output[1] + "fh"

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
317    return m
318
319
def create_submodule_multiple_hooks_multiple_inputs():
    """Return a ModuleForwardMultipleInputs with two pre-hooks and two forward
    hooks on its submodule.

    Checks chaining of multi-input hooks on a submodule. ``forward_hook1``
    widens the output to a 3-tuple that ``forward_hook2`` consumes.
    The caller is expected to pass ``"no_pre_hook"`` as the string argument.
    """
    root = ModuleForwardMultipleInputs("outer_mod_name", "inner_mod_name")

    def pre_hook1(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "inner_mod_name"
        assert input[1] == "no_pre_hook"
        return (["pre_hook_override_name"], "pre_hook_override1")

    def pre_hook2(self, input: Tuple[List[str], str]) -> Tuple[List[str], str]:
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override1"
        return (["pre_hook_override_name"], "pre_hook_override2")

    def forward_hook1(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str]
    ):
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override2"
        assert output[1] == "pre_hook_override2_"
        suffixed = output[1] + "fh1"
        return output[0], suffixed, suffixed

    def forward_hook2(
        self, input: Tuple[List[str], str], output: Tuple[List[str], str, str]
    ):
        assert self.name == "inner_mod_name"
        assert input[1] == "pre_hook_override2"
        assert output[1] == "pre_hook_override2_fh1"
        return output[0], output[1] + "_fh2"

    inner = root.submodule
    for pre in (pre_hook1, pre_hook2):
        inner.register_forward_pre_hook(pre)
    for post in (forward_hook1, forward_hook2):
        inner.register_forward_hook(post)
    return root
358    return m
359
360
def create_submodule_forward_single_input():
    """Return a ModuleForwardSingleInput with hooks on its submodule.

    Checks submodule hooks for a single-argument ``forward``. The forward hook
    returns the output unchanged.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        # The outer forward already appended "_outermod".
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        return output

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
379    return m
380
381
def create_submodule_to_call_directly_with_hooks():
    """Return a ModuleForwardSingleInput whose submodule carries hooks.

    Used to check that the submodule's hooks run when the submodule itself is
    called directly. The pre-hook accepts any input, so it does not depend on
    the outer module's "_outermod" suffix.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        return ("pre_hook_override_name",)

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
399    return m
400
401
def create_submodule_same_hook_repeated():
    """Return a ModuleForwardSingleInput whose submodule has the same pre-hook
    and forward hook each registered twice.

    Checks that a submodule can run an identical hook more than once.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        return (input[0] + "_ph",)

    def forward_hook(self, input: Tuple[str], output: str):
        # "_outermod" from the outer forward, then "_ph" from each pre_hook run.
        assert self.name == "inner_mod_name"
        assert input == ("a_outermod_ph_ph",)
        return output + "_fh"

    inner = root.submodule
    # Register the very same function objects twice each.
    for _ in range(2):
        inner.register_forward_pre_hook(pre_hook)
    for _ in range(2):
        inner.register_forward_hook(forward_hook)
    return root
421    return m
422
423
def create_submodule_hook_return_nothing():
    """Return a ModuleForwardSingleInput whose submodule hooks return None.

    Because neither hook returns a value, the submodule's input and output
    pass through unchanged. The hooks only assert what they observe.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> None:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"

    def forward_hook(self, input: Tuple[str], output: str):
        # The pre-hook returned nothing, so the input is unchanged.
        assert self.name == "inner_mod_name"
        assert input == ("a_outermod",)

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
439    return m
440
441
def create_submodule_multiple_hooks_single_input():
    """Return a ModuleForwardSingleInput with two pre-hooks and two forward
    hooks on its submodule.

    Checks chaining of single-input hooks on a submodule. Each hook asserts
    the value produced by the hook registered before it.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook1(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        return ("pre_hook_override_name",)

    def pre_hook2(self, input: Tuple[str]) -> Tuple[str]:
        assert self.name == "inner_mod_name"
        assert input[0] == "pre_hook_override_name"
        return ("pre_hook_override_name2",)

    def forward_hook1(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_inner_mod"
        return output + "_fwh1"

    def forward_hook2(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name2",)
        assert output == "pre_hook_override_name2_inner_mod_fwh1"
        return output

    inner = root.submodule
    for pre in (pre_hook1, pre_hook2):
        inner.register_forward_pre_hook(pre)
    for post in (forward_hook1, forward_hook2):
        inner.register_forward_hook(post)
    return root
473    return m
474
475
def create_forward_tuple_input():
    """Return a ModuleForwardTupleInput with hooks on both outer and inner modules.

    Covers a ``forward`` whose only argument is itself a tuple.

    In eager mode, a pre-hook return value that is not a tuple gets wrapped in
    one, so it can be passed on to the next pre-hook. A returned tuple is
    instead taken as the whole argument tuple and is not wrapped. For a
    single-tuple argument, ``return (11,)`` would therefore lose the inner
    tuple.

    To keep single-tuple inputs consistent with every other forward signature,
    pre-hooks must wrap their result in an extra tuple (``((11,),)``). The
    TorchScript schema checker enforces this.
    """
    outer = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")

    def pre_hook_outermod(self, input: Tuple[Tuple[int]]) -> Tuple[Tuple[int]]:
        # Extra outer tuple required; '(11,)' alone loses the inner tuple in eager.
        return ((11,),)

    def pre_hook_innermod(self, input: Tuple[Tuple[int]]) -> Tuple[Tuple[int]]:
        # Extra outer tuple required; '(22,)' alone loses the inner tuple in eager.
        return ((22,),)

    def forward_hook_outermod(self, input: Tuple[Tuple[int]], output: int):
        # `output` is the int produced by forward_hook_innermod below.
        return (11,)

    def forward_hook_innermod(self, input: Tuple[Tuple[int]], output: Tuple[int]):
        return 22

    inner = outer.submodule
    outer.register_forward_pre_hook(pre_hook_outermod)
    inner.register_forward_pre_hook(pre_hook_innermod)
    outer.register_forward_hook(forward_hook_outermod)
    inner.register_forward_hook(forward_hook_innermod)
    return outer
507    return m
508
509
def create_submodule_forward_single_input_return_not_tupled():
    """Return a ModuleForwardSingleInput whose submodule pre-hook returns a
    bare (non-tuple) value.

    Matches eager behavior: a pre-hook may return its replacement input
    without wrapping it in a tuple. The forward hook then sees the value
    wrapped as a 1-tuple.
    """
    root = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

    def pre_hook(self, input: Tuple[str]) -> str:
        assert self.name == "inner_mod_name"
        assert input[0] == "a_outermod"
        # Deliberately not wrapped in a tuple, unlike the other factories.
        return "pre_hook_override_name"

    def forward_hook(self, input: Tuple[str], output: str):
        assert self.name == "inner_mod_name"
        assert input == ("pre_hook_override_name",)
        return output + "_fh"

    inner = root.submodule
    inner.register_forward_pre_hook(pre_hook)
    inner.register_forward_hook(forward_hook)
    return root
530    return m
531