import argparse
import os
import sys

import torch


# The hook test modules used by test_jit_hooks.cpp live under the pytorch
# test directory (two levels above this file); put it on sys.path so the
# ``jit.test_hooks_modules`` import below resolves.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import (
    create_forward_tuple_input,
    create_module_forward_multiple_inputs,
    create_module_forward_single_input,
    create_module_hook_return_nothing,
    create_module_multiple_hooks_multiple_inputs,
    create_module_multiple_hooks_single_input,
    create_module_no_forward_input,
    create_module_same_hook_repeated,
    create_submodule_forward_multiple_inputs,
    create_submodule_forward_single_input,
    create_submodule_hook_return_nothing,
    create_submodule_multiple_hooks_multiple_inputs,
    create_submodule_multiple_hooks_single_input,
    create_submodule_same_hook_repeated,
    create_submodule_to_call_directly_with_hooks,
)


def main():
    """Script each JIT forward-hook / pre-hook test module and save it.

    Reads the required ``--export-script-module-to`` option and, for every
    test case, writes ``<option>_<test case name>.pt`` via
    ``torch.jit.script`` followed by ``torch.jit.save``. Prints a
    confirmation line once all files have been written.
    """
    parser = argparse.ArgumentParser(
        description="Serialize a script modules with hooks attached"
    )
    parser.add_argument("--export-script-module-to", required=True)
    args = parser.parse_args()

    # The filename prefix is published as a module-level global.
    global save_name
    save_name = args.export_script_module_to + "_"

    # Test-case name -> module instance. Every module is built here, in this
    # order, before any of them is scripted or written to disk.
    cases = {
        "test_submodule_forward_single_input": (
            create_submodule_forward_single_input()
        ),
        "test_submodule_forward_multiple_inputs": (
            create_submodule_forward_multiple_inputs()
        ),
        "test_submodule_multiple_hooks_single_input": (
            create_submodule_multiple_hooks_single_input()
        ),
        "test_submodule_multiple_hooks_multiple_inputs": (
            create_submodule_multiple_hooks_multiple_inputs()
        ),
        "test_submodule_hook_return_nothing": (
            create_submodule_hook_return_nothing()
        ),
        "test_submodule_same_hook_repeated": (
            create_submodule_same_hook_repeated()
        ),
        "test_module_forward_single_input": (
            create_module_forward_single_input()
        ),
        "test_module_forward_multiple_inputs": (
            create_module_forward_multiple_inputs()
        ),
        "test_module_multiple_hooks_single_input": (
            create_module_multiple_hooks_single_input()
        ),
        "test_module_multiple_hooks_multiple_inputs": (
            create_module_multiple_hooks_multiple_inputs()
        ),
        "test_module_hook_return_nothing": create_module_hook_return_nothing(),
        "test_module_same_hook_repeated": create_module_same_hook_repeated(),
        "test_module_no_forward_input": create_module_no_forward_input(),
        "test_forward_tuple_input": create_forward_tuple_input(),
        "test_submodule_to_call_directly_with_hooks": (
            create_submodule_to_call_directly_with_hooks()
        ),
    }

    for case_name, module in cases.items():
        scripted = torch.jit.script(module)
        torch.jit.save(scripted, f"{save_name}{case_name}.pt")

    print("OK: completed saving modules with hooks!")


if __name__ == "__main__":
    main()