1 #include <gtest/gtest.h>
2
3 #include <test/cpp/jit/test_utils.h>
4 #include <cstdlib>
5 #include <iostream>
6 #include <sstream>
7
8 #include <caffe2/serialize/inline_container.h>
9 #include <torch/csrc/jit/mobile/module.h>
10 #include <torch/csrc/jit/runtime/calculate_necessary_args.h>
11 #include <torch/csrc/jit/serialization/export.h>
12 #include <torch/csrc/jit/serialization/export_bytecode.h>
13 #include <torch/csrc/jit/serialization/import.h>
14 #include <torch/csrc/jit/serialization/import_source.h>
15 #include <torch/script.h>
16 #include <torch/torch.h>
17
18 #include "caffe2/serialize/istream_adapter.h"
19
20 namespace torch {
21 namespace jit {
22
23 namespace {
24
// Lowers a JIT module to the mobile format and lifts it back into a full
// JIT module from its Python source + constants, exercising both
// serialization directions. Used by the TestSourceRoundTrip tests below.
Module roundtripThroughMobile(const Module& m) {
  ExtraFilesMap sourceFiles;
  std::vector<IValue> constantValues;
  jitModuleToPythonCodeAndConstants(m, &sourceFiles, &constantValues);
  CompilationOptions options;
  mobile::Module loweredModule = jitModuleToMobile(m, options);
  // 8 is the operator-version passed through to the reconstruction path.
  return jitModuleFromSourceAndConstants(
      loweredModule._ivalue(), sourceFiles, constantValues, 8);
}
34
35 template <class Functor>
expectThrowsEq(Functor && functor,const char * expectedMessage)36 inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
37 try {
38 std::forward<Functor>(functor)();
39 } catch (const Error& e) {
40 EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
41 return;
42 }
43 ADD_FAILURE() << "Expected to throw exception with message \""
44 << expectedMessage << "\" but didn't throw";
45 }
46
47 } // namespace
48
TEST(SerializationTest, ExtraFilesHookPreference) {
  // Tests that an extra file written explicitly has precedence over
  // extra files written by a hook
  // TODO: test for the warning, too
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);
  std::ostringstream oss;
  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  // The hook also writes metadata.json; the explicit value must win.
  SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
    return {{"metadata.json", "def"}};
  });
  module->save(oss, extra_files);
  SetExportModuleExtraFilesHook(nullptr);

  std::istringstream iss(oss.str());
  std::unordered_map<std::string, std::string> loaded_extra_files;
  loaded_extra_files["metadata.json"] = "";
  // Note: the previously-constructed IStreamAdapter here was unused;
  // torch::jit::load reads the stream directly.
  auto loaded_module = torch::jit::load(iss, torch::kCPU, loaded_extra_files);
  // The explicitly-passed value, not the hook's "def", must be loaded.
  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
}
79
TEST(SerializationTest, ExtraFileHooksNoSecret) {
  // With no export hook installed, only explicitly-written extra files
  // are present in the archive; unknown names load back empty.
  std::stringstream stream;
  {
    Module saved("__torch__.m");
    ExtraFilesMap written;
    written["metadata.json"] = "abc";
    saved.save(stream, written);
  }
  stream.seekg(0);
  {
    ExtraFilesMap requested;
    requested["metadata.json"] = "";
    requested["secret.json"] = "";
    jit::load(stream, std::nullopt, requested);
    ASSERT_EQ(requested["metadata.json"], "abc");
    // secret.json was never written, so it stays empty.
    ASSERT_EQ(requested["secret.json"], "");
  }
}
99
TEST(SerializationTest, ExtraFileHooksWithSecret) {
  // An export hook can inject additional extra files ("secret.json")
  // alongside the explicitly-written ones, and both load back.
  std::stringstream stream;
  {
    SetExportModuleExtraFilesHook([](const Module&) -> ExtraFilesMap {
      return {{"secret.json", "topsecret"}};
    });
    Module saved("__torch__.m");
    ExtraFilesMap written;
    written["metadata.json"] = "abc";
    saved.save(stream, written);
    // Reset the global hook so later tests are unaffected.
    SetExportModuleExtraFilesHook(nullptr);
  }
  stream.seekg(0);
  {
    ExtraFilesMap requested;
    requested["metadata.json"] = "";
    requested["secret.json"] = "";
    jit::load(stream, std::nullopt, requested);
    ASSERT_EQ(requested["metadata.json"], "abc");
    ASSERT_EQ(requested["secret.json"], "topsecret");
  }
}
122
TEST(SerializationTest, TypeTags) {
  // Verifies that pickle round-tripping preserves container type tags:
  // each value's loaded type must be mutually a subtype of the expected
  // type (i.e. equivalent).
  auto list = c10::List<c10::List<int64_t>>();
  list.push_back(c10::List<int64_t>({1, 2, 3}));
  list.push_back(c10::List<int64_t>({4, 5, 6}));
  auto dict = c10::Dict<std::string, at::Tensor>();
  dict.insert("Hello", torch::ones({2, 2}));
  auto dict_list = c10::List<c10::Dict<std::string, at::Tensor>>();
  for (size_t i = 0; i < 5; i++) {
    auto another_dict = c10::Dict<std::string, at::Tensor>();
    another_dict.insert("Hello" + std::to_string(i), torch::ones({2, 2}));
    dict_list.push_back(another_dict);
  }
  auto tuple = std::tuple<int, std::string>(2, "hi");
  struct TestItem {
    IValue value;
    TypePtr expected_type;
  };
  std::vector<TestItem> items = {
      {list, ListType::create(ListType::create(IntType::get()))},
      {2, IntType::get()},
      {dict, DictType::create(StringType::get(), TensorType::get())},
      {dict_list,
       ListType::create(
           DictType::create(StringType::get(), TensorType::get()))},
      {tuple, TupleType::create({IntType::get(), StringType::get()})}};
  // Iterate by const reference: the body only reads, so there is no need
  // to copy each IValue/TypePtr pair per iteration (previously suppressed
  // via NOLINT(performance-for-range-copy)).
  for (const auto& item : items) {
    auto bytes = torch::pickle_save(item.value);
    auto loaded = torch::pickle_load(bytes);
    ASSERT_TRUE(loaded.type()->isSubtypeOf(*item.expected_type));
    ASSERT_TRUE(item.expected_type->isSubtypeOf(*loaded.type()));
  }
}
156
TEST(SerializationTest, TestJitStream_CUDA) {
  // Verifies that a scripted model that uses a CUDA stream can be
  // deserialized and produces the same cat() result as eager execution.
  torch::jit::Module model;
  std::vector<torch::jit::IValue> inputs;
  // Deserialize the ScriptModule from a file using torch::jit::load().
  // Load the scripted model. This should have been generated by tests_setup.py
  // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
  // NOTE(review): this test requires "saved_stream_model.pt" to exist in the
  // working directory — it fails at load() otherwise.
  model = torch::jit::load("saved_stream_model.pt");

  // The model returns a tuple: (stream_was_set, a, b, c).
  auto output = model.forward(inputs);
  const auto& list_of_elements = output.toTupleRef().elements();
  auto is_stream_s = list_of_elements[0].toBool();

  // a,b: These are the two input tensors
  // c: This is output tensor generated by the operation torch.cat(a,b)
  auto a = list_of_elements[1].toTensor();
  auto b = list_of_elements[2].toTensor();
  auto c = list_of_elements[3].toTensor();
  // op: this is used to verify if the cat operation produced the same results
  // as that on the GPU with torch.cat
  auto op = at::cat({a, b}, 0);

  // Check if the stream is set
  ASSERT_TRUE(is_stream_s);
  // Check if the sizes of the outputs (op and c) is same on the GPU and CPU
  ASSERT_EQ(op.sizes(), c.sizes());
  // Check if both the output tensors are equal
  ASSERT_TRUE(op.equal(c));
}
185
TEST(TestSourceRoundTrip, UpsampleNearest2d) {
  // upsample_nearest2d must survive the mobile round trip unchanged.
  Module original("m");
  original.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto expected = original.forward(inputs);

  Module roundtripped = roundtripThroughMobile(original);
  auto actual = roundtripped.forward(inputs);

  // Compare the tensors produced before and after the round trip.
  ASSERT_TRUE(actual.toTensor().equal(expected.toTensor()));
}
205
TEST(TestSourceRoundTrip, CheckAttrAccess) {
  // A bool attribute registered on the module must survive the mobile
  // round trip and be readable via attr().
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  Module m2 = roundtripThroughMobile(m);
  bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
  // Use the gtest assertion for consistency with the rest of this file
  // (was AT_ASSERT), which also yields better failure output.
  ASSERT_TRUE(mobile_optimized);
}
213
TEST(TestSourceRoundTrip,
     MethodInvocation) { // NOLINT (use =delete in gtest)
  // Each program defines test_func in a different way (default args,
  // inlined inner call, plain body); all must round-trip through mobile
  // and produce the same result as the original module.
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
        def test_func(self, x, b : int = 4):
          return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
        def add_with_default_arg(self, x, b : int = 4):
          return self.foo + x + b
        def test_func(self, x):
          return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
        def test_func(self, x):
          b = 4
          return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    Module m2 = roundtripThroughMobile(m);
    const auto& test_func = m2.get_method("test_func");
    IValue res;
    // Invoke several times to make sure repeated calls are stable.
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    // gtest assertion for consistency with the rest of the file (was
    // AT_ASSERT); prints both values on failure.
    ASSERT_EQ(resd, refd);
  }
}
257
TEST(SerializationTest, ParentDirNotExist) {
  // Saving to a path whose parent directory is missing must fail with a
  // clear error message.
  const auto attemptSave = []() {
    auto t = torch::nn::Linear(5, 5);
    torch::save(t, "./doesnotexist/file.pt");
  };
  expectThrowsEq(
      attemptSave, "Parent directory ./doesnotexist does not exist.");
}
266
267 #ifdef WIN32
TEST(SerializationTest, WindowsDrivePathTest) {
  // "ZZZ" is typically not a valid drive letter.
  // We expect to see "ZZZ:\\" or "ZZZ:/" in the error message.
  // Note: slash should be included for the drive letter parent in Windows.
  const std::pair<const char*, const char*> cases[] = {
      {"ZZZ:\\file.pt", "Parent directory ZZZ:\\ does not exist."},
      {"ZZZ:/file.pt", "Parent directory ZZZ:/ does not exist."},
  };
  for (const auto& [path, expected] : cases) {
    expectThrowsEq(
        [path]() {
          auto t = torch::nn::Linear(5, 5);
          torch::save(t, path);
        },
        expected);
  }
}
285
TEST(SerializationTest, WindowsTempPathTest) {
  // Test for verifying file saving and loading in the temporary folder.
  // std::getenv may return nullptr; constructing std::string from a null
  // pointer is undefined behavior, so guard the lookup explicitly.
  const char* temp_env = std::getenv("TEMP");
  ASSERT_NE(temp_env, nullptr) << "TEMP environment variable is not set";
  std::string file_path = std::string(temp_env) + "/file.pt";
  auto t1 = torch::tensor(1.0);
  torch::save(t1, file_path);
  torch::Tensor t2;
  torch::load(t2, file_path);
  // Exact round trip expected: zero tolerance on the comparison.
  ASSERT_TRUE(t1.allclose(t2, 0.0, 0.0));
}
296 #endif
297
TEST(SerializationTest, CalculateNecessaryArgsTest) {
  // When the only argument is a constant equal to its schema default,
  // no args (and no out-args) need to be serialized.
  auto schema = torch::schema(
      "sync_stream(int stream_id = -1) -> ()",
      c10::AliasAnalysisKind::CONSERVATIVE);

  auto graph = std::make_shared<Graph>();
  auto default_valued_const = graph->insertConstant(-1);
  auto necessary = CalculateNecessaryArgs(
      schema.arguments(), {default_valued_const}, true);
  EXPECT_EQ(0, necessary.first);
  EXPECT_EQ(0, necessary.second);
}
309
TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
  // Verifies setShouldLoadDebugSymbol: with debug symbols the thrown
  // error points at the original Python source; without them it points
  // at the TorchScript-generated serialized source instead.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def test_func(self, x):
      b = 4
      return self.foo + x + b
    )");
  m.define(
      R"(
    def exception(self):
      assert False, "message"
    )");
  std::stringstream ss;
  m.save(ss);
  ss.seekg(0);
  // The debug_pkl record is only visible when debug-symbol loading is on.
  caffe2::serialize::PyTorchStreamReader reader(&ss);
  reader.setShouldLoadDebugSymbol(true);
  EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
  ss.seekg(0);
  // Default load keeps debug info: the error shows the original source.
  Module m2 = torch::jit::load(ss);
  std::string error_msg = R"(
    def exception(self):
      assert False, "message"
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
  ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);

  ss.seekg(0);
  // NO DEBUG trace so error message points to torchscript generated
  // source instead of original python source.
  std::string error2 = R"(
    def exception(self: __torch__.m) -> NoneType:
      _0 = uninitialized(NoneType)
      ops.prim.RaiseException("AssertionError: message")
      ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      return _0
  )";
  // load(..., load_debug_files=false) drops the debug records.
  Module m3 = torch::jit::load(ss, std::nullopt, false);
  ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
}
353
TEST(SerializationTest, TestPickleAppend) {
  // Hand-rolled pickle stream: PROTO 2, EMPTY_LIST, BININT1 2, APPEND,
  // STOP — i.e. the list [2].
  const std::vector<char> payload(
      {'\x80', char(2), ']', 'K', char(2), 'a', '.'});

  torch::IValue decoded = torch::jit::unpickle(payload.data(), payload.size());

  torch::IValue reference = c10::impl::GenericList(at::AnyType::get());
  reference.toList().push_back(2);
  ASSERT_EQ(reference, decoded);
}
363
364 } // namespace jit
365 } // namespace torch
366