xref: /aosp_15_r20/external/pytorch/test/jit/test_hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6from typing import Tuple
7
8import torch
9from jit.test_hooks_modules import (
10    create_forward_tuple_input,
11    create_module_forward_multiple_inputs,
12    create_module_forward_single_input,
13    create_module_hook_return_nothing,
14    create_module_multiple_hooks_multiple_inputs,
15    create_module_multiple_hooks_single_input,
16    create_module_no_forward_input,
17    create_module_same_hook_repeated,
18    create_submodule_forward_multiple_inputs,
19    create_submodule_forward_single_input,
20    create_submodule_forward_single_input_return_not_tupled,
21    create_submodule_hook_return_nothing,
22    create_submodule_multiple_hooks_multiple_inputs,
23    create_submodule_multiple_hooks_single_input,
24    create_submodule_no_forward_input,
25    create_submodule_same_hook_repeated,
26    create_submodule_to_call_directly_with_hooks,
27    ModuleDirectforwardSubmodCall,
28    ModuleForwardSingleInput,
29    ModuleForwardTupleInput,
30)
31
32
33# Make the helper files in test/ importable
34pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
35sys.path.append(pytorch_test_dir)
36from torch.testing._internal.jit_utils import JitTestCase
37
38
if __name__ == "__main__":
    # This module only defines test cases; running it directly is refused
    # with a hint pointing at the test_jit.py entry point instead.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
45
46
47# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """Tests for scripting modules that carry forward hooks and pre-hooks.

    Most tests build a module with a factory from ``jit.test_hooks_modules``
    and hand it to ``checkModule`` together with a tuple of positional
    forward arguments. The remaining tests register hooks inline and either
    compare scripted vs. eager results directly or assert on the
    ``RuntimeError`` raised by ``torch.jit.script`` for invalid hooks.
    """

    # NOTE(review): checkModule comes from JitTestCase (not visible here);
    # presumably it scripts the module and compares scripted vs. eager output
    # for the given positional args -- confirm against jit_utils.

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        """Calling a submodule directly gives the same result eager vs. scripted."""
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # A one-element tuple of positional args, like the other single-input
        # tests; note that `(["a"])` would be a list, not a tuple.
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        # forward takes a single tuple argument, so the args tuple wraps it.
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    def test_hook_method_name_collision(self):
        """Scripting fails when a hook's name clashes with a method name."""
        # Hooks can't have the same name as methods (the submodule presumably
        # defines a method named `foo` -- see jit.test_hooks_modules).
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        """Scripting fails when two distinct hooks share one Python name."""
        # Test edge case of two hooks sharing name but not python definition
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        # Intentionally rebinds `prehook` to a different function object.
        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        # Same check for forward hooks; the two `hook` definitions below also
        # differ in their signatures.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        def hook(self, input: Tuple[str]):
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    def test_module_direct_forward_invocation(self):
        """Hooks run on module call, but not on an explicit ``forward`` call."""
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # The hooks rewrite the input and append "_fh" to the output, so a
        # hooked call must differ from a bare forward call.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        """Submodule hooks are skipped when the parent calls its ``forward``."""
        # NOTE(review): ModuleDirectforwardSubmodCall presumably invokes
        # `self.submodule.forward(...)` rather than `self.submodule(...)`;
        # confirm in jit.test_hooks_modules.
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip(
        "TODO: hook compilation hint is not yet surfaced in the error message"
    )
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "This error occurred while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        """Each malformed pre-hook signature is rejected with a specific error."""
        # correct signature: pre_hook_c(self, input: Tuple[str])
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        """Each malformed forward-hook signature is rejected with a specific error."""
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # The first hook changes the output type to Tuple[str], so the second
        # hook's `output: str` annotation no longer matches.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)
396