xref: /aosp_15_r20/external/pytorch/test/jit_hooks/test_jit_hooks.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/script.h>
2 
3 #include <memory>
4 #include <string>
5 #include <sstream>
6 #include <vector>
7 
8 #include <iostream>
9 
test_module_forward_invocation_no_hooks_run(const std::string & path_to_exported_script_module)10 void test_module_forward_invocation_no_hooks_run(
11     const std::string &path_to_exported_script_module) {
12   std::cout << "testing: "
13             << "test_module_forward_invocation_no_hooks_run" << std::endl;
14   torch::jit::Module module =
15       torch::jit::load(path_to_exported_script_module + "_" +
16                        "test_module_forward_multiple_inputs" + ".pt");
17   std::vector<torch::jit::IValue> inputs = {torch::List<std::string>({"a"}),
18                                             torch::jit::IValue("no_pre_hook")};
19 
20   auto output = module(inputs);
21   auto output_forward = module.forward(inputs);
22   torch::jit::IValue correct_direct_output =
23       std::tuple<torch::List<std::string>, std::string>(
24           {"a", "outer_mod_name", "inner_mod_name"}, "no_pre_hook_");
25   std::cout << "----- module output: " << output << std::endl;
26   std::cout << "----- module forward output: " << output_forward << std::endl;
27   AT_ASSERT(correct_direct_output == output_forward);
28 }
29 
test_submodule_called_directly_with_hooks(const std::string & path_to_exported_script_module)30 void test_submodule_called_directly_with_hooks(
31     const std::string &path_to_exported_script_module) {
32   std::cout << "testing: "
33             << "test_submodule_to_call_directly_with_hooks" << std::endl;
34   torch::jit::Module module =
35       torch::jit::load(path_to_exported_script_module + "_" +
36                        "test_submodule_to_call_directly_with_hooks" + ".pt");
37   torch::jit::Module submodule = *module.modules().begin();
38   std::vector<torch::jit::IValue> inputs = {"a"};
39 
40   auto output = submodule(inputs);
41   torch::jit::IValue correct_output = "pre_hook_override_name_inner_mod_fh";
42   std::cout << "----- submodule's output: " << output << std::endl;
43   std::cout << "----- expected output   : " << correct_output << std::endl;
44   AT_ASSERT(correct_output == correct_output);
45 }
46 
47 struct HooksTestCase {
48   std::string name;
49   std::vector<torch::jit::IValue> inputs;
50   torch::jit::IValue output;
HooksTestCaseHooksTestCase51   HooksTestCase(std::string name, std::vector<torch::jit::IValue> inputs,
52                 torch::jit::IValue output)
53       : name(name), inputs(std::move(inputs)), output(std::move(output)) {}
54 };
55 
main(int argc,const char * argv[])56 int main(int argc, const char *argv[]) {
57   if (argc != 2) {
58     std::cerr << "usage: test_jit_hooks <path-to-exported-script-module>\n";
59     return -1;
60   }
61   const std::string path_to_exported_script_module = argv[1];
62   std::cout << "path to exported module:" << path_to_exported_script_module
63             << std::endl;
64   std::cout << "Tesing JIT Hooks in CPP" << std::endl;
65 
66   // Note: Modules loaded in this file are produced in /test/jit_hooks/model.py
67 
68   std::vector<HooksTestCase> test_cases = {
69       HooksTestCase("test_submodule_multiple_hooks_single_input",
70                     {torch::jit::IValue("a")},
71                     "pre_hook_override_name2_inner_mod_fwh1"),
72       HooksTestCase("test_submodule_hook_return_nothing",
73                     {torch::jit::IValue("a")}, "a_outermod_inner_mod"),
74       HooksTestCase("test_submodule_same_hook_repeated",
75                     {torch::jit::IValue("a")},
76                     "a_outermod_ph_ph_inner_mod_fh_fh"),
77       HooksTestCase("test_submodule_forward_single_input",
78                     {torch::jit::IValue("a")},
79                     "pre_hook_override_name_inner_mod"),
80       HooksTestCase(
81           "test_submodule_multiple_hooks_multiple_inputs",
82           {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
83           std::tuple<torch::List<std::string>, std::string>(
84               {"pre_hook_override_name", "inner_mod_name"},
85               "pre_hook_override2_fh1_fh2")),
86       HooksTestCase(
87           "test_submodule_forward_multiple_inputs",
88           {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
89           std::tuple<torch::List<std::string>, std::string>(
90               {"pre_hook_override_name", "inner_mod_name"},
91               "pre_hook_override_fh")),
92       HooksTestCase("test_module_forward_single_input",
93                     {torch::jit::IValue("a")},
94                     "pre_hook_override_name_outermod_inner_mod_fh"),
95       HooksTestCase("test_module_multiple_hooks_single_input",
96                     {torch::jit::IValue("a")},
97                     "pre_hook_override_name2_outermod_inner_mod_fh1_fh2"),
98       HooksTestCase("test_module_hook_return_nothing",
99                     {torch::jit::IValue("a")}, "a_outermod_inner_mod"),
100       HooksTestCase("test_module_same_hook_repeated", {torch::jit::IValue("a")},
101                     "a_ph_ph_outermod_inner_mod_fh_fh"),
102       HooksTestCase(
103           "test_module_forward_multiple_inputs",
104           {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
105           std::tuple<torch::List<std::string>, std::string>(
106               {"pre_hook_override_name", "outer_mod_name", "inner_mod_name"},
107               "pre_hook_override_fh")),
108       HooksTestCase(
109           "test_module_multiple_hooks_multiple_inputs",
110           {torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")},
111           std::tuple<torch::List<std::string>, std::string>(
112               {"pre_hook_override_name2", "outer_mod_name", "inner_mod_name"},
113               "pre_hook_override_fh1_fh2")),
114       HooksTestCase("test_module_no_forward_input", {}, torch::jit::IValue()),
115       HooksTestCase("test_forward_tuple_input", {std::tuple<int>(11)},
116                     {std::tuple<int>(11)}),
117   };
118 
119   for (HooksTestCase &test_case : test_cases) {
120     std::cout << "testing: " << test_case.name << std::endl;
121     torch::jit::Module module = torch::jit::load(
122         path_to_exported_script_module + "_" + test_case.name + ".pt");
123     torch::jit::IValue output = module(test_case.inputs);
124     std::cout << "----- module's output: " << output << std::endl;
125     std::cout << "----- expected output: " << test_case.output << std::endl;
126     AT_ASSERT(output == test_case.output);
127   }
128 
129   // special test cases that don't call the imported module directly
130   test_module_forward_invocation_no_hooks_run(path_to_exported_script_module);
131   test_submodule_called_directly_with_hooks(path_to_exported_script_module);
132 
133   std::cout << "JIT CPP Hooks okay!" << std::endl;
134 
135   return 0;
136 }
137