xref: /aosp_15_r20/external/pytorch/test/jit_hooks/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport os
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
# Put the parent test/ directory on sys.path so the hook test modules (the same
# ones used by test_jit_hooks.cpp) can be imported from jit/test_hooks_modules.py.
9*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
10*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir)
11*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hooks_modules import (
12*da0073e9SAndroid Build Coastguard Worker    create_forward_tuple_input,
13*da0073e9SAndroid Build Coastguard Worker    create_module_forward_multiple_inputs,
14*da0073e9SAndroid Build Coastguard Worker    create_module_forward_single_input,
15*da0073e9SAndroid Build Coastguard Worker    create_module_hook_return_nothing,
16*da0073e9SAndroid Build Coastguard Worker    create_module_multiple_hooks_multiple_inputs,
17*da0073e9SAndroid Build Coastguard Worker    create_module_multiple_hooks_single_input,
18*da0073e9SAndroid Build Coastguard Worker    create_module_no_forward_input,
19*da0073e9SAndroid Build Coastguard Worker    create_module_same_hook_repeated,
20*da0073e9SAndroid Build Coastguard Worker    create_submodule_forward_multiple_inputs,
21*da0073e9SAndroid Build Coastguard Worker    create_submodule_forward_single_input,
22*da0073e9SAndroid Build Coastguard Worker    create_submodule_hook_return_nothing,
23*da0073e9SAndroid Build Coastguard Worker    create_submodule_multiple_hooks_multiple_inputs,
24*da0073e9SAndroid Build Coastguard Worker    create_submodule_multiple_hooks_single_input,
25*da0073e9SAndroid Build Coastguard Worker    create_submodule_same_hook_repeated,
26*da0073e9SAndroid Build Coastguard Worker    create_submodule_to_call_directly_with_hooks,
27*da0073e9SAndroid Build Coastguard Worker)
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker# Create saved modules for JIT forward hooks and pre-hooks
def main():
    """Script each hook test module and serialize it to disk.

    The required ``--export-script-module-to`` argument is used as a filename
    prefix: every module below is compiled with ``torch.jit.script`` and saved
    with ``torch.jit.save`` to ``<prefix>_<test name>.pt``. The prefix (with
    its trailing underscore) is also stored in the module-level global
    ``save_name``.
    """
    arg_parser = argparse.ArgumentParser(
        description="Serialize a script modules with hooks attached"
    )
    arg_parser.add_argument("--export-script-module-to", required=True)
    parsed = arg_parser.parse_args()

    # Published as a module-level global, exactly as the original script did.
    global save_name
    save_name = parsed.export_script_module_to + "_"

    # All modules are instantiated here, before any scripting/saving starts.
    # Each key becomes the suffix of the corresponding saved file name.
    modules_by_test = {
        "test_submodule_forward_single_input": (
            create_submodule_forward_single_input()
        ),
        "test_submodule_forward_multiple_inputs": (
            create_submodule_forward_multiple_inputs()
        ),
        "test_submodule_multiple_hooks_single_input": (
            create_submodule_multiple_hooks_single_input()
        ),
        "test_submodule_multiple_hooks_multiple_inputs": (
            create_submodule_multiple_hooks_multiple_inputs()
        ),
        "test_submodule_hook_return_nothing": (
            create_submodule_hook_return_nothing()
        ),
        "test_submodule_same_hook_repeated": (
            create_submodule_same_hook_repeated()
        ),
        "test_module_forward_single_input": (
            create_module_forward_single_input()
        ),
        "test_module_forward_multiple_inputs": (
            create_module_forward_multiple_inputs()
        ),
        "test_module_multiple_hooks_single_input": (
            create_module_multiple_hooks_single_input()
        ),
        "test_module_multiple_hooks_multiple_inputs": (
            create_module_multiple_hooks_multiple_inputs()
        ),
        "test_module_hook_return_nothing": create_module_hook_return_nothing(),
        "test_module_same_hook_repeated": create_module_same_hook_repeated(),
        "test_module_no_forward_input": create_module_no_forward_input(),
        "test_forward_tuple_input": create_forward_tuple_input(),
        "test_submodule_to_call_directly_with_hooks": (
            create_submodule_to_call_directly_with_hooks()
        ),
    }

    for test_name, eager_module in modules_by_test.items():
        scripted_module = torch.jit.script(eager_module)
        torch.jit.save(scripted_module, f"{save_name}{test_name}.pt")

    print("OK: completed saving modules with hooks!")


if __name__ == "__main__":
    main()
92