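// End-to-end tests for the AOTInductor C++ runtime. Each test loads a
// compiled model shared library together with reference inputs/outputs that
// the build serializes (e.g. as data.pt) into CMAKE_CURRENT_BINARY_DIR; the
// attribute names read below match what that serialization step wrote.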
#include <gtest/gtest.h>
#include <filesystem>
#include <string>
#include <vector>

#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
#if defined(USE_CUDA) || defined(USE_ROCM)
#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
#endif
#include <torch/script.h>
#include <torch/torch.h>

#define STR_VALUE(x) #x
#define STRINGIZE(x) STR_VALUE(x)
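// Two-step stringification: STRINGIZE(CMAKE_CURRENT_BINARY_DIR) first expands
// the macro supplied by the build system, then STR_VALUE turns the expansion
// into a string literal. A single level of # would stringize the macro name
// itself rather than its value.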

namespace {

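// Loads the AOTI-compiled model shared library referenced in data.pt and
// checks that running it on the recorded inputs reproduces the recorded
// reference outputs.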
void test_aoti(const std::string& device, bool use_runtime_constant_folding) {
  torch::NoGradGuard no_grad;

  std::string data_path =
      (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
          .string();
  torch::jit::script::Module data_loader = torch::jit::load(data_path);
  std::string suffix = use_runtime_constant_folding
      ? device + "_use_runtime_constant_folding"
      : device;
  std::string path_attr = "model_so_path_" + suffix;
  std::string inputs_attr = "inputs_" + suffix;
  std::string outputs_attr = "outputs_" + suffix;
  const auto& model_so_path = data_loader.attr(path_attr.c_str()).toStringRef();
  auto input_tensors =
      data_loader.attr(inputs_attr.c_str()).toTensorList().vec();
  const auto& ref_output_tensors =
      data_loader.attr(outputs_attr.c_str()).toTensorList().vec();

  std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner;
  if (device == "cpu") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_so_path);
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device == "cuda") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_so_path);
#endif
  } else {
    // Abort the test on an unsupported device.
    FAIL() << "unsupported device: " << device;
  }
  auto actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}

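// Loads the TorchScript wrapper around the AOTI-compiled model
// (script_model_<device>.pt) and compares model.forward() against the
// recorded reference outputs.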
void test_aoti_script(const std::string& device) {
  torch::NoGradGuard no_grad;

  std::string script_model = "script_model_" + device + ".pt";
  std::string model_path =
      (std::filesystem::path(
           STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / script_model)
          .string();
  torch::jit::script::Module model = torch::jit::load(model_path);

  std::string sample_data_path =
      (std::filesystem::path(
           STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "script_data.pt")
          .string();
  torch::jit::script::Module sample_data = torch::jit::load(sample_data_path);
  std::string inputs_attr = "inputs_" + device;
  std::string outputs_attr = "outputs_" + device;
  const auto& inputs = sample_data.attr(inputs_attr.c_str()).toList().vec();
  const auto& ref_output_tensors =
      sample_data.attr(outputs_attr.c_str()).toTensorVector();
  auto outputs = model.forward(inputs).toTuple()->elements();
  ASSERT_EQ(outputs.size(), ref_output_tensors.size());
  for (size_t i = 0; i < ref_output_tensors.size(); i++) {
    ASSERT_TRUE(torch::allclose(outputs[i].toTensor(), ref_output_tensors[i]));
  }
}

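// Exercises runtime constant updates on the active buffer: an incomplete map
// with validation enabled must throw, random weights must change the outputs,
// and restoring the real weights must match the reference again.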
void test_aoti_constants_update(
    const std::string& device,
    bool use_runtime_constant_folding) {
  torch::NoGradGuard no_grad;

  std::string data_path =
      (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
          .string();

  torch::jit::script::Module data_loader = torch::jit::load(data_path);
  std::string suffix = use_runtime_constant_folding
      ? device + "_use_runtime_constant_folding"
      : device;
  std::string path_attr = "model_so_path_" + suffix;
  std::string inputs_attr = "inputs_" + suffix;
  std::string outputs_attr = "outputs_" + suffix;
  std::string weights_attr = "w_pre_" + suffix;
  std::string add_attr = "w_add_" + suffix;
  const auto& model_so_path = data_loader.attr(path_attr.c_str()).toStringRef();
  auto input_tensors =
      data_loader.attr(inputs_attr.c_str()).toTensorList().vec();
  const auto& ref_output_tensors =
      data_loader.attr(outputs_attr.c_str()).toTensorList().vec();

  const auto& weight_tensors =
      data_loader.attr(weights_attr.c_str()).toTensor();
  const auto& add_tensors = data_loader.attr(add_attr.c_str()).toTensor();

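  // TensorConstantMap maps constant names to raw at::Tensor pointers; the
  // heap allocations below are deliberately leaked to keep the test short.
  // missing_map omits L__self___w_add so it can trip full-update validation.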
  torch::inductor::TensorConstantMap missing_map, rand_map, real_map;
  missing_map.emplace("L__self___w_pre", new at::Tensor(at::randn({4, 4})));
  rand_map.emplace("L__self___w_pre", new at::Tensor(at::randn({10})));
  rand_map.emplace("L__self___w_add", new at::Tensor(at::randn({10})));
  real_map.emplace("L__self___w_pre", new at::Tensor(weight_tensors));
  real_map.emplace("L__self___w_add", new at::Tensor(add_tensors));

  std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner;
  if (device == "cpu") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_so_path);
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device == "cuda") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_so_path);
#endif
  } else {
    FAIL() << "unsupported device: " << device;
  }
  // By default, buffer #1 gets loaded with the burned-in weights, so the
  // results are correct.
  auto actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Updating with an incomplete map while requesting full-update validation
  // should throw.
  EXPECT_THROW(
      runner->update_constant_buffer(
          missing_map, /* use_inactive = */ false,
          /* validate_full_update = */ true),
      std::runtime_error);

  // Update buffer #1 with the random w_pre weight, validation disabled.
  runner->update_constant_buffer(
      missing_map, /* use_inactive = */ false,
      /* validate_full_update = */ false);
  actual_output_tensors = runner->run(input_tensors);
  if (use_runtime_constant_folding) {
    // At this point the update has only touched the original weight; the
    // weight actually consumed is the "folded" one, so it has no effect yet.
    ASSERT_TRUE(
        torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
    runner->run_const_fold(/* use_inactive = */ false);
    actual_output_tensors = runner->run(input_tensors);
  }
  ASSERT_FALSE(
      torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Update with the real map; outputs should match the reference again.
  runner->update_constant_buffer(
      real_map, /* use_inactive = */ false,
      /* validate_full_update = */ false);
  if (use_runtime_constant_folding) {
    runner->run_const_fold(/* use_inactive = */ false);
  }
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Update with the fully random map; outputs should diverge.
  runner->update_constant_buffer(
      rand_map, /* use_inactive = */ false,
      /* validate_full_update = */ false);
  if (use_runtime_constant_folding) {
    runner->run_const_fold(/* use_inactive = */ false);
  }
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_FALSE(
      torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}

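// Exercises the double-buffered constant store: weights are staged into the
// inactive buffer with update_inactive_constant_buffer() and only take
// effect once swap_constant_buffer() makes that buffer active.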
void test_aoti_double_buffering(
    const std::string& device,
    bool use_runtime_constant_folding) {
  torch::NoGradGuard no_grad;

  std::string data_path =
      (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt")
          .string();

  torch::jit::script::Module data_loader = torch::jit::load(data_path);
  std::string suffix = use_runtime_constant_folding
      ? device + "_use_runtime_constant_folding"
      : device;
  std::string path_attr = "model_so_path_" + suffix;
  std::string inputs_attr = "inputs_" + suffix;
  std::string outputs_attr = "outputs_" + suffix;
  std::string weights_attr = "w_pre_" + suffix;
  std::string add_attr = "w_add_" + suffix;
  const auto& model_so_path = data_loader.attr(path_attr.c_str()).toStringRef();
  auto input_tensors =
      data_loader.attr(inputs_attr.c_str()).toTensorList().vec();
  const auto& ref_output_tensors =
      data_loader.attr(outputs_attr.c_str()).toTensorList().vec();

  const auto& weight_tensors =
      data_loader.attr(weights_attr.c_str()).toTensor();
  const auto& add_tensors = data_loader.attr(add_attr.c_str()).toTensor();

  torch::inductor::TensorConstantMap rand_map, real_map;
  rand_map.emplace("L__self___w_pre", new at::Tensor(at::randn({4, 4})));
  rand_map.emplace("L__self___w_add", new at::Tensor(at::randn({4, 4})));
  real_map.emplace("L__self___w_pre", new at::Tensor(weight_tensors));
  real_map.emplace("L__self___w_add", new at::Tensor(add_tensors));

  std::unique_ptr<torch::inductor::AOTIModelContainerRunner> runner;
  if (device == "cpu") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCpu>(
        model_so_path);
#if defined(USE_CUDA) || defined(USE_ROCM)
  } else if (device == "cuda") {
    runner = std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
        model_so_path);
#endif
  } else {
    FAIL() << "unsupported device: " << device;
  }
  // By default, buffer #1 gets loaded with the burned-in weights, so the
  // results are correct.
  auto actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Write the real weights into inactive buffer #2 and activate it. Results
  // should still be correct, since this is the real constant map.
  runner->update_inactive_constant_buffer(real_map);
  if (use_runtime_constant_folding) {
    runner->run_const_fold(/* use_inactive = */ true);
  }
  runner->swap_constant_buffer();
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Write random weights into the now-inactive buffer #1, but do not swap
  // yet; the active buffer still holds the real weights.
  runner->update_inactive_constant_buffer(rand_map);
  if (use_runtime_constant_folding) {
    runner->run_const_fold(/* use_inactive = */ true);
  }
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Swap buffer #1 in. The random weights should now produce incorrect
  // results.
  runner->swap_constant_buffer();
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_FALSE(
      torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Swap back to buffer #2, which holds the real constants.
  runner->swap_constant_buffer();
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}

#if defined(USE_CUDA) || defined(USE_ROCM)
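// Same double-buffering flow, but for a model whose graph also embeds
// non-parameter tensor constants; swapping buffers must preserve them.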
void test_aoti_double_buffering_with_tensor_constants() {
  torch::NoGradGuard no_grad;

  std::string data_path = (std::filesystem::path(
                               STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) /
                           "data_with_tensor_constants.pt")
                              .string();

  torch::jit::script::Module data_loader = torch::jit::load(data_path);
  std::string path_attr = "model_so_path";
  std::string inputs_attr = "inputs";
  std::string w_attr = "w";
  std::string outputs_attr = "outputs";
  const auto& model_so_path = data_loader.attr(path_attr.c_str()).toStringRef();
  auto input_tensors =
      data_loader.attr(inputs_attr.c_str()).toTensorList().vec();
  const auto& w_tensors = data_loader.attr(w_attr.c_str()).toTensor();
  const auto& ref_output_tensors =
      data_loader.attr(outputs_attr.c_str()).toTensorList().vec();

  torch::inductor::TensorConstantMap real_map;
  real_map.emplace("L__self___w", new at::Tensor(w_tensors));

  auto runner =
      std::make_unique<torch::inductor::AOTIModelContainerRunnerCuda>(
          model_so_path);

  // By default, buffer #1 gets loaded with the burned-in weights, so the
  // results are correct.
  auto actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));

  // Write the real map into inactive buffer #2 and activate it. Results
  // should still be correct, since the tensor constants are copied over.
  runner->update_inactive_constant_buffer(real_map);
  runner->swap_constant_buffer();
  actual_output_tensors = runner->run(input_tensors);
  ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0]));
}
#endif

} // namespace

namespace torch {
namespace aot_inductor {

TEST(AotInductorTest, BasicTestCpu) {
  test_aoti("cpu", false);
}

TEST(AotInductorTest, BasicScriptTestCpu) {
  test_aoti_script("cpu");
}

#ifdef USE_CUDA
TEST(AotInductorTest, BasicTestCuda) {
  test_aoti("cuda", true);
  test_aoti("cuda", false);
}

TEST(AotInductorTest, BasicScriptTestCuda) {
  test_aoti_script("cuda");
}

TEST(AotInductorTest, RuntimeUpdateConstantsCuda) {
  test_aoti_constants_update("cuda", true);
}

TEST(AotInductorTest, UpdateConstantsCuda) {
  test_aoti_constants_update("cuda", false);
}

TEST(AotInductorTest, RuntimeUpdateInactiveConstantsCuda) {
  test_aoti_double_buffering("cuda", true);
}

TEST(AotInductorTest, UpdateInactiveConstantsCuda) {
  test_aoti_double_buffering("cuda", false);
}

TEST(AotInductorTest, UpdateInactiveConstantsWithTensorConstantsCuda) {
  test_aoti_double_buffering_with_tensor_constants();
}
#endif

} // namespace aot_inductor
} // namespace torch