xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_interpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3 
4 #include <ATen/Parallel.h>
5 #include <c10/core/DeviceType.h>
6 #include <test/cpp/jit/test_utils.h>
7 #include <torch/csrc/jit/runtime/instruction.h>
8 #include <torch/jit.h>
9 #include <torch/script.h>
10 #include <torch/torch.h>
11 
12 namespace torch {
13 namespace jit {
14 
15 class TypeCheckTest : public ::testing::Test {
16  protected:
TypeCheckTest()17   TypeCheckTest() : interp(makeInterp()) {}
18 
19   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
20   InterpreterState interp;
21 
22  private:
makeInterp()23   static InterpreterState makeInterp() {
24     auto graph = std::make_shared<Graph>();
25     std::unordered_map<std::string, Value*> vmap;
26     parseIR(
27         R"IR(
28 graph(%a.1 : Tensor,
29       %b.1 : Tensor):
30   %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
31   return (%t0, %t1, %type_matched)
32   )IR",
33         &*graph,
34         vmap);
35 
36     Code function(graph, "");
37     return InterpreterState(function);
38   }
39 };
40 
// Inputs that agree with the fixture's expected types in shape, gradient
// requirement and device: the check must report true and hand both tensors
// back unchanged.
TEST_F(TypeCheckTest, MatchingType) {
  auto first = at::zeros({2, 2}, at::kFloat);
  auto second = at::ones({3, 3}, at::kFloat);
  first.set_requires_grad(true);
  first = first.to(at::kCPU);

  std::vector<IValue> io_stack{first, second};
  interp.run(io_stack);

  // After the run the stack holds: checked input 0, checked input 1, verdict.
  ASSERT_TRUE(exactlyEqual(io_stack[0].toTensor(), first));
  ASSERT_TRUE(exactlyEqual(io_stack[1].toTensor(), second));
  ASSERT_TRUE(io_stack[2].toBool());
}
53 
// The second input is 2x2 while the fixture's graph expects 3x3, so the check
// must report false.
TEST_F(TypeCheckTest, SizeMismatch) {
  auto first = at::zeros({2, 2}, at::kFloat);
  auto wrong_size = at::ones({2, 2}, at::kFloat); // expected shape is (3, 3)
  first.set_requires_grad(true);
  first = first.to(at::kCPU);

  std::vector<IValue> io_stack{first, wrong_size};
  interp.run(io_stack);

  // Slot 2 carries the TypeCheck verdict.
  ASSERT_FALSE(io_stack[2].toBool());
}
63 
// The first input does not require grad while the fixture's graph expects
// requires_grad=1, so the check must report false.
TEST_F(TypeCheckTest, GradientMismatch) {
  auto no_grad = at::zeros({2, 2}, at::kFloat);
  auto second = at::ones({3, 3}, at::kFloat);
  no_grad = no_grad.to(at::kCPU);
  no_grad.set_requires_grad(false); // expected type has requires_grad=1

  std::vector<IValue> io_stack{no_grad, second};
  interp.run(io_stack);

  // Slot 2 carries the TypeCheck verdict.
  ASSERT_FALSE(io_stack[2].toBool());
}
73 
// The first input is converted to an integer dtype while the fixture's graph
// expects Float, so the check must report false.
TEST_F(TypeCheckTest, ScalarTypeMismatch) {
  auto as_int = at::zeros({2, 2}, at::kFloat);
  auto second = at::ones({3, 3}, at::kFloat);
  as_int = as_int.to(at::kCPU);
  as_int.set_requires_grad(true);
  as_int = as_int.to(at::kInt); // expected scalar type is Float

  std::vector<IValue> io_stack{as_int, second};
  interp.run(io_stack);

  // Slot 2 carries the TypeCheck verdict.
  ASSERT_FALSE(io_stack[2].toBool());
}
84 
// The first input lives on CUDA while the fixture's graph expects device=cpu,
// so the check must report false.
TEST_F(TypeCheckTest, DeviceMismatch_CUDA) {
  auto on_gpu = at::zeros({2, 2}, at::kFloat);
  auto second = at::ones({3, 3}, at::kFloat);
  on_gpu.set_requires_grad(true);
  on_gpu = on_gpu.to(at::kCUDA); // expected device is cpu

  std::vector<IValue> io_stack{on_gpu, second};
  interp.run(io_stack);

  // Slot 2 carries the TypeCheck verdict.
  ASSERT_FALSE(io_stack[2].toBool());
}
94 
95 // TODO: These tests weren't doing anything.
96 // TEST(TypeCheckErrorTest, EmptyCheckRaises) {
97 //   // Test empty Typecheck raises an internal assertion
98 //   auto graph = std::make_shared<Graph>();
99 //   std::unordered_map<std::string, Value*> vmap;
100 //   EXPECT_ANY_THROW(parseIR(
101 //       R"IR(
102 // graph(%a.1 : Tensor,
103 //       %b.1 : Tensor):
104 //   %type_matched : bool = prim::TypeCheck()
105 //   return (%type_matched)
106 //   )IR",
107 //       &*graph,
108 //       vmap));
109 // }
110 
111 // TODO: These tests weren't doing anything.
112 // TEST(TypeCheckErrorTest, WrongInputOutputCountRaises) {
113 //   // Test for assertion if num_inputs + 1 != num_outputs
114 //   auto graph = std::make_shared<Graph>();
115 //   std::unordered_map<std::string, Value*> vmap;
116 //   EXPECT_ANY_THROW(parseIR(
117 //       R"IR(
118 // graph(%a.1 : Tensor,
119 //       %b.1 : Tensor):
120 //   %type_matched : bool = prim::TypeCheck(%a.1)
121 //   return (%type_matched)
122 //   )IR",
123 //       &*graph,
124 //       vmap));
125 // }
126 
// Runs the graph returned by build_lstm() through the interpreter on CUDA
// tensors and requires bit-exact agreement with the lstm() helper called on
// the same arguments.
// NOTE(review): build_lstm()/lstm() are defined in the test utilities;
// presumably they are the graph and eager forms of one LSTM cell -- confirm.
TEST(InterpreterTest, Basic_CUDA) {
  constexpr int kBatch = 4;
  constexpr int kInputWidth = 256;
  constexpr int kSteps = 32;

  int hidden_width = 2 * kInputWidth;

  auto input = at::randn({kSteps, kBatch, kInputWidth}, at::kCUDA);
  auto hx = at::randn({kBatch, hidden_width}, at::kCUDA);
  auto cx = at::randn({kBatch, hidden_width}, at::kCUDA);
  auto w_ih = t_def(at::randn({4 * hidden_width, kInputWidth}, at::kCUDA));
  auto w_hh = t_def(at::randn({4 * hidden_width, hidden_width}, at::kCUDA));

  auto lstm_graph = build_lstm();
  Code lstm_code(lstm_graph, "");
  InterpreterState lstm_state(lstm_code);

  // Interpreter result first: the reference call below overwrites hx and cx.
  auto interpreted = run(lstm_state, {input[0], hx, cx, w_ih, w_hh});
  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);

  ASSERT_TRUE(exactlyEqual(interpreted[0], hx));
  ASSERT_TRUE(exactlyEqual(interpreted[1], cx));
}
149 
// Checks the per-operator "number of specified arguments" table that
// MobileCode records for several export-analysis graphs supplied by the test
// utilities.
TEST(InterpreterTest, IgnorableArgsInSchema) {
  // Plain graph: exactly two operators recorded, each with 4 specified args.
  auto plain_graph = build_mobile_export_analysis_graph();
  MobileCode plain_code(plain_graph, "");
  auto plain_args = plain_code.op_to_num_specified_args();
  ASSERT_TRUE(plain_args.size() == 2);
  ASSERT_TRUE(plain_args["aten::slice.Tensor"] == 4);
  ASSERT_TRUE(plain_args["aten::slice.str"] == 4);

  // Vararg graph: prim::tolist should never be registered in the table.
  auto vararg_graph = build_mobile_export_analysis_graph_with_vararg();
  MobileCode vararg_code(vararg_graph, "");
  auto vararg_args = vararg_code.op_to_num_specified_args();
  ASSERT_TRUE(vararg_args.count("prim::tolist") == 0);

  // Nested graph: same counts as the plain graph for both slice overloads.
  auto nested_graph = build_mobile_export_analysis_graph_nested();
  MobileCode nested_code(nested_graph, "");
  auto nested_args = nested_code.op_to_num_specified_args();
  ASSERT_TRUE(nested_args["aten::slice.Tensor"] == 4);
  ASSERT_TRUE(nested_args["aten::slice.str"] == 4);

  // Non-const graph: conv2d is recorded with 6 specified args.
  auto non_const_graph = build_mobile_export_analysis_graph_non_const();
  MobileCode non_const_code(non_const_graph, "");
  auto non_const_args = non_const_code.op_to_num_specified_args();
  ASSERT_TRUE(non_const_args["aten::conv2d"] == 6);
}
177 
// Same table as in IgnorableArgsInSchema, but for the graph returned by
// build_mobile_export_with_out(): a single entry, aten::add.out, with 3
// specified arguments.
TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) {
  auto out_graph = build_mobile_export_with_out();
  MobileCode out_code(out_graph, "");
  auto specified_args = out_code.op_to_num_specified_args();

  ASSERT_TRUE(specified_args.size() == 1);
  // this should be 3 when the add_out flag is set to True
  ASSERT_TRUE(specified_args["aten::add.out"] == 3);
}
186 
// Loads a pre-serialized scripted module whose forward() forks two matmuls,
// runs it with InterpreterState::runAsync through a counting task launcher,
// and checks that the launcher was invoked at least once.
TEST(InterpreterTest, runAsyncBasicTest) {
  /*
  TODO: there are some problems with C++ parsing of script programs involving
  fork. Use the test module below for now.
  issue about this: github.com/pytorch/pytorch/issues/46368
  The test module file is generated by the following:
    class DemoModule(torch.nn.Module):
      def forward(self):
        r1 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        r2 = torch.jit.fork(torch.mm, torch.rand(100,100),torch.rand(100,100))
        return r1.wait() + r2.wait()
  demo = DemoModule()
  torch.jit.save(torch.jit.script(demo), 'test_interpreter_async.pt')
  */
  // The model file is looked up next to this source file.
  std::string filePath(__FILE__);
  auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  testModelFile.append("test_interpreter_async.pt");
  auto model = load(testModelFile);
  auto graph = model.get_method("forward").graph();
  Code function(graph, "");

  int asyncCounter = 0;
  std::mutex mtx;
  // A dummy executor which forwards to at::launch but counts every task it is
  // handed. The counter is guarded by `mtx`; the lock is scoped so it is
  // released before the task is dispatched and on every exit path.
  auto launcher = [&](std::function<void()> f) {
    {
      std::lock_guard<std::mutex> guard(mtx);
      ++asyncCounter;
    }
    at::launch(f);
  };

  std::vector<IValue> stack;
  // NOLINTNEXTLINE(modernize-use-emplace)
  stack.push_back(model._ivalue());
  InterpreterState interp(function, launcher);
  interp.runAsync(stack)->wait();

  // Read the counter under the same mutex that guards the writes, then assert
  // on the snapshot so a failure prints the observed value.
  int launchesSeen = 0;
  {
    std::lock_guard<std::mutex> guard(mtx);
    launchesSeen = asyncCounter;
  }
  ASSERT_GT(launchesSeen, 0);
}
223 
TEST(EnableRethrowCaughtExceptionTest,EnableRethrowCaughtExceptionTestRethrowsCaughtException)224 TEST(
225     EnableRethrowCaughtExceptionTest,
226     EnableRethrowCaughtExceptionTestRethrowsCaughtException) {
227   auto graph = std::make_shared<Graph>();
228   std::unordered_map<std::string, Value*> vmap;
229   parseIR(
230       R"IR(
231 graph(%0 : Tensor,
232       %1 : Tensor):
233   %2 : int = prim::Constant[value=2]()
234   %3 : Tensor = aten::add(%0, %1, %2)
235   return (%3)
236   )IR",
237       &*graph,
238       vmap);
239   Code function(graph, "");
240   InterpreterState interp = InterpreterState(function);
241   auto a = at::zeros({2, 2}, at::kFloat);
242   auto b = at::ones({2, 3}, at::kFloat);
243   a.set_requires_grad(true);
244   a = a.to(at::kCPU);
245   std::vector<IValue> stack({a, b});
246 
247   bool original_flag_value = FLAGS_torch_jit_enable_rethrow_caught_exception;
248   bool exception_handled = false;
249   try {
250     FLAGS_torch_jit_enable_rethrow_caught_exception = false;
251     interp.run(stack);
252   } catch (std::runtime_error& e) {
253     exception_handled = true;
254     std::string exception_msg = e.what();
255     EXPECT_THAT(
256         exception_msg,
257         ::testing::HasSubstr("%3 : Tensor = aten::add(%0, %1, %2)"));
258     EXPECT_THAT(
259         exception_msg,
260         ::testing::HasSubstr(
261             "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1"));
262   }
263   EXPECT_TRUE(exception_handled);
264 
265   exception_handled = false;
266   try {
267     FLAGS_torch_jit_enable_rethrow_caught_exception = true;
268     interp.run(stack);
269   } catch (c10::Error& e) {
270     exception_handled = true;
271     std::string exception_msg = e.what_without_backtrace();
272     EXPECT_STREQ(
273         exception_msg.c_str(),
274         "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
275   }
276   EXPECT_TRUE(exception_handled);
277 
278   FLAGS_torch_jit_enable_rethrow_caught_exception = true;
279   c10::intrusive_ptr<Future> future = interp.runAsync(stack);
280   future->wait();
281   ASSERT_TRUE(future->completed());
282   ASSERT_TRUE(future->hasError());
283   try {
284     std::rethrow_exception(future->exception_ptr());
285   } catch (c10::Error& e) {
286     std::string exception_msg = e.what_without_backtrace();
287     EXPECT_STREQ(
288         exception_msg.c_str(),
289         "The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1");
290   }
291 
292   FLAGS_torch_jit_enable_rethrow_caught_exception = original_flag_value;
293 }
294 
295 } // namespace jit
296 } // namespace torch
297