import argparse
import os
import sys

import torch


# Make the sibling ``jit`` test package importable: the module factories below
# live in ``<pytorch test dir>/jit/test_hooks_modules.py``. The files this script
# writes are presumably loaded by the C++ hook tests (test_jit_hooks.cpp) —
# NOTE(review): confirm the consumer if the output naming ever changes.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import (
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
)


def main(argv=None):
    """Script every JIT forward-hook / pre-hook test module and save it to disk.

    Each module is compiled with ``torch.jit.script`` and written with
    ``torch.jit.save`` to ``<prefix>_<test name>.pt``, where ``<prefix>`` is
    the value of ``--export-script-module-to`` used verbatim (it may contain a
    directory component).

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.
    """
    parser = argparse.ArgumentParser(
        description="Serialize script modules with hooks attached"
    )
    parser.add_argument(
        "--export-script-module-to",
        required=True,
        help="path prefix for the saved modules; each one is written to "
        "<prefix>_<test name>.pt",
    )
    options = parser.parse_args(argv)
    # Kept local: nothing outside this function needs the prefix.
    save_prefix = options.export_script_module_to + "_"

    # (output name, module) pairs. The names form part of the saved filenames,
    # so they must stay in sync with whatever loads these files.
    tests = [
        (
            "test_submodule_forward_single_input",
            create_submodule_forward_single_input(),
        ),
        (
            "test_submodule_forward_multiple_inputs",
            create_submodule_forward_multiple_inputs(),
        ),
        (
            "test_submodule_multiple_hooks_single_input",
            create_submodule_multiple_hooks_single_input(),
        ),
        (
            "test_submodule_multiple_hooks_multiple_inputs",
            create_submodule_multiple_hooks_multiple_inputs(),
        ),
        ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
        ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
        ("test_module_forward_single_input", create_module_forward_single_input()),
        (
            "test_module_forward_multiple_inputs",
            create_module_forward_multiple_inputs(),
        ),
        (
            "test_module_multiple_hooks_single_input",
            create_module_multiple_hooks_single_input(),
        ),
        (
            "test_module_multiple_hooks_multiple_inputs",
            create_module_multiple_hooks_multiple_inputs(),
        ),
        ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
        ("test_module_same_hook_repeated", create_module_same_hook_repeated()),
        ("test_module_no_forward_input", create_module_no_forward_input()),
        ("test_forward_tuple_input", create_forward_tuple_input()),
        (
            "test_submodule_to_call_directly_with_hooks",
            create_submodule_to_call_directly_with_hooks(),
        ),
    ]

    for name, model in tests:
        m_scripted = torch.jit.script(model)
        torch.jit.save(m_scripted, f"{save_prefix}{name}.pt")

    print("OK: completed saving modules with hooks!")


if __name__ == "__main__":
    main()