xref: /aosp_15_r20/external/pytorch/test/jit_hooks/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os
3import sys
4
5import torch
6
7
# Put the parent of this script's directory on sys.path so that the
# `jit.test_hooks_modules` import below resolves.
# NOTE(review): the original note says these modules come "from
# test_jit_hooks.cpp"; presumably that C++ test consumes them — confirm.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
11from jit.test_hooks_modules import (
12    create_forward_tuple_input,
13    create_module_forward_multiple_inputs,
14    create_module_forward_single_input,
15    create_module_hook_return_nothing,
16    create_module_multiple_hooks_multiple_inputs,
17    create_module_multiple_hooks_single_input,
18    create_module_no_forward_input,
19    create_module_same_hook_repeated,
20    create_submodule_forward_multiple_inputs,
21    create_submodule_forward_single_input,
22    create_submodule_hook_return_nothing,
23    create_submodule_multiple_hooks_multiple_inputs,
24    create_submodule_multiple_hooks_single_input,
25    create_submodule_same_hook_repeated,
26    create_submodule_to_call_directly_with_hooks,
27)
28
29
30# Create saved modules for JIT forward hooks and pre-hooks
def main():
    """Script and save every hook test module to disk.

    Reads the required ``--export-script-module-to`` command-line option and
    stores that value plus a trailing underscore in the module-level
    ``save_name``. Each module built by the imported ``create_*`` factories is
    then compiled with ``torch.jit.script`` and written with ``torch.jit.save``
    to ``<save_name><test name>.pt``. A confirmation line is printed once all
    modules have been saved.
    """
    global save_name

    arg_parser = argparse.ArgumentParser(
        description="Serialize a script modules with hooks attached"
    )
    arg_parser.add_argument("--export-script-module-to", required=True)
    save_name = arg_parser.parse_args().export_script_module_to + "_"

    # Every module is constructed up front, before anything is scripted or
    # saved; dict insertion order keeps the files written in listing order.
    modules_by_test = {
        "test_submodule_forward_single_input": create_submodule_forward_single_input(),
        "test_submodule_forward_multiple_inputs": (
            create_submodule_forward_multiple_inputs()
        ),
        "test_submodule_multiple_hooks_single_input": (
            create_submodule_multiple_hooks_single_input()
        ),
        "test_submodule_multiple_hooks_multiple_inputs": (
            create_submodule_multiple_hooks_multiple_inputs()
        ),
        "test_submodule_hook_return_nothing": create_submodule_hook_return_nothing(),
        "test_submodule_same_hook_repeated": create_submodule_same_hook_repeated(),
        "test_module_forward_single_input": create_module_forward_single_input(),
        "test_module_forward_multiple_inputs": create_module_forward_multiple_inputs(),
        "test_module_multiple_hooks_single_input": (
            create_module_multiple_hooks_single_input()
        ),
        "test_module_multiple_hooks_multiple_inputs": (
            create_module_multiple_hooks_multiple_inputs()
        ),
        "test_module_hook_return_nothing": create_module_hook_return_nothing(),
        "test_module_same_hook_repeated": create_module_same_hook_repeated(),
        "test_module_no_forward_input": create_module_no_forward_input(),
        "test_forward_tuple_input": create_forward_tuple_input(),
        "test_submodule_to_call_directly_with_hooks": (
            create_submodule_to_call_directly_with_hooks()
        ),
    }

    for test_name, module in modules_by_test.items():
        # Compile to TorchScript, then serialize under the prefixed file name.
        torch.jit.save(torch.jit.script(module), f"{save_name}{test_name}.pt")

    print("OK: completed saving modules with hooks!")
88
89
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
92