#include <gtest/gtest.h>

#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <deque>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace torch::test;
using namespace torch::nn;
using namespace torch::optim;

namespace {
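// A small 2-8-1 sigmoid MLP used by the XOR serialization tests below.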
Sequential xor_model() {
  return Sequential(
      Linear(2, 8),
      Functional(at::sigmoid),
      Linear(8, 1),
      Functional(at::sigmoid));
}

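// Round-trips a tensor through an in-memory stream and returns the reloaded
// copy.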
torch::Tensor save_and_load(torch::Tensor input) {
  std::stringstream stream;
  torch::save(input, stream);
  torch::Tensor tensor;
  torch::load(tensor, stream);
  return tensor;
}
} // namespace

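// Asserts that two optimizer param groups hold the same parameters and equal
// derived options.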
template <typename DerivedOptions>
void is_optimizer_param_group_equal(
    const OptimizerParamGroup& lhs,
    const OptimizerParamGroup& rhs) {
  const auto& lhs_params = lhs.params();
  const auto& rhs_params = rhs.params();

  ASSERT_TRUE(lhs_params.size() == rhs_params.size());
  for (const auto j : c10::irange(lhs_params.size())) {
    ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
  }
  ASSERT_TRUE(
      static_cast<const DerivedOptions&>(lhs.options()) ==
      static_cast<const DerivedOptions&>(rhs.options()));
}

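// Asserts that two optimizer state maps hold equal derived per-parameter state
// under matching keys.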
template <typename DerivedOptimizerParamState>
void is_optimizer_state_equal(
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
        lhs_state,
    const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
        rhs_state) {
  ASSERT_TRUE(lhs_state.size() == rhs_state.size());
  for (const auto& value : lhs_state) {
    auto found = rhs_state.find(value.first);
    ASSERT_TRUE(found != rhs_state.end());
    const DerivedOptimizerParamState& lhs_curr_state =
        static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
    const DerivedOptimizerParamState& rhs_curr_state =
        static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
    ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
  }
}

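// Generic optimizer round-trip test: three models share parameters; one
// optimizer is saved after a step and reloaded into a fresh optimizer whose
// learning rate was perturbed (to prove options are restored), then param
// groups, state, and a further step are compared.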
template <
    typename OptimizerClass,
    typename DerivedOptimizerOptions,
    typename DerivedOptimizerParamState>
void test_serialize_optimizer(
    DerivedOptimizerOptions options,
    bool only_has_global_state = false) {
  torch::manual_seed(0);
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }
  // Make some optimizers
  auto optim1 = OptimizerClass(
      {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
  auto optim2 = OptimizerClass(model2->parameters(), options);
  auto optim2_2 = OptimizerClass(model2->parameters(), options);
  auto optim3 = OptimizerClass(model3->parameters(), options);
  auto optim3_2 = OptimizerClass(model3->parameters(), options);
  for (auto& param_group : optim3_2.param_groups()) {
    const double lr = param_group.options().get_lr();
    // Change the learning rate; loading should overwrite it. Otherwise the
    // test could not tell whether options are saved and loaded correctly.
    param_group.options().set_lr(lr + 0.01);
  }

  auto x = torch::ones({10, 5});

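  // A dummy closure is supplied so the same lambda also works for LBFGS, which
  // requires a loss closure; for the other optimizers it does not affect the
  // parameter update.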
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  // Do 2 steps of model 1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 1 step of model 3
  step(optim3, model3);

  // Save the optimizer, then load it into optim3_2 (whose learning rate was
  // perturbed above).
  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);

  auto& optim3_2_param_groups = optim3_2.param_groups();
  auto& optim3_param_groups = optim3.param_groups();
  auto& optim3_2_state = optim3_2.state();
  auto& optim3_state = optim3.state();

  // optim3_2 should have param_groups of size 1 and state of size state_size
  ASSERT_TRUE(optim3_2_param_groups.size() == 1);
  // state_size = 2 for all optimizers except LBFGS, which only maintains one
  // global state
  unsigned state_size = only_has_global_state ? 1 : 2;
  ASSERT_TRUE(optim3_2_state.size() == state_size);

  // optim3_2 and optim3 should have param_groups and state of the same size
  ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
  ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());

  // Check correctness of the serialization logic for optimizer.param_groups_
  // and optimizer.state_.
  for (const auto i : c10::irange(optim3_2_param_groups.size())) {
    is_optimizer_param_group_equal<DerivedOptimizerOptions>(
        optim3_2_param_groups[i], optim3_param_groups[i]);
    is_optimizer_state_equal<DerivedOptimizerParamState>(
        optim3_2_state, optim3_state);
  }

  // Do a second step for model 3 with the reloaded optimizer
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}

/// Utility function to save a value of `int64_t` type.
void write_int_value(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const int64_t& value) {
  archive.write(key, c10::IValue(value));
}
// Utility function to save a vector of buffers.
template <typename BufferContainer>
void write_tensors_to_archive(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const BufferContainer& buffers) {
  archive.write(
      key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
  for (const auto index : c10::irange(buffers.size())) {
    archive.write(
        key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
  }
}

// Utility function to save a vector of step buffers.
void write_step_buffers(
    torch::serialize::OutputArchive& archive,
    const std::string& key,
    const std::vector<int64_t>& steps) {
  std::vector<torch::Tensor> tensors;
  tensors.reserve(steps.size());
  for (const auto& step : steps) {
    tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
  }
  write_tensors_to_archive(archive, key, tensors);
}

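// Expects `funcname` to emit exactly one warning mentioning the old
// serialization format while loading `optimizer` from `filename`.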
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  {                                                                          \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  }

TEST(SerializeTest, KeysFunc) {
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  for (const auto i : c10::irange(3)) {
    output_archive.write(
        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
  }
  output_archive.save_to(tempfile.name);
  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  std::vector<std::string> keys = input_archive.keys();
  ASSERT_EQ(keys.size(), 3);
  for (const auto i : c10::irange(keys.size())) {
    ASSERT_EQ(keys[i], "element/" + std::to_string(i));
  }
}

TEST(SerializeTest, TryReadFunc) {
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  for (const auto i : c10::irange(3)) {
    output_archive.write(
        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
  }
  output_archive.save_to(tempfile.name);
  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  c10::IValue ivalue;
  ASSERT_FALSE(input_archive.try_read("1", ivalue));
  ASSERT_TRUE(input_archive.try_read("element/1", ivalue));
  ASSERT_EQ(ivalue.toInt(), 1);
}

TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
  auto x = torch::randn({5, 5}, options);
  {
    auto expected = torch::conj(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    auto expected = torch::_neg_view(x);
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    auto expected = torch::conj(torch::_neg_view(x));
    auto actual = save_and_load(expected);

    ASSERT_TRUE(actual.defined());
    ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
    ASSERT_TRUE(actual.allclose(expected));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public-facing
    // yet. If `ZeroTensor` serialization is supported in the future, this test
    // should start failing!
    auto t = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
  }
}

TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  auto tempfile = c10::make_tempfile();
  torch::save(x, tempfile.name);

  torch::Tensor y;
  torch::load(y, tempfile.name);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  std::string serialized;
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

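  // Load again through a read-at-position callback plus a size callback,
  // exercising the streaming overload of torch::load.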
  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        size_t nbytes =
            std::min(static_cast<size_t>(pos) + n, serialized.size()) - pos;
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}

TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(0, 1, 5);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  auto x = torch::randn({11, 5});
  x = x.slice(1, 1, 4);
  auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}

TEST(SerializeTest, ErrorOnMissingKey) {
  struct B : torch::nn::Module {
    B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}

TEST(SerializeTest, XOR) {
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(SerializeTest, Optim) {
  auto model1 = Linear(5, 2);
  auto model2 = Linear(5, 2);
  auto model3 = Linear(5, 2);

  // Models 1, 2, 3 will have the same parameters.
  auto model_tempfile = c10::make_tempfile();
  torch::save(model1, model_tempfile.name);
  torch::load(model2, model_tempfile.name);
  torch::load(model3, model_tempfile.name);

  auto param1 = model1->named_parameters();
  auto param2 = model2->named_parameters();
  auto param3 = model3->named_parameters();
  for (const auto& p : param1) {
    ASSERT_TRUE(p->allclose(param2[p.key()]));
    ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
  }

  // Make some optimizers with momentum (and thus state)
  auto optim1 = torch::optim::SGD(
      model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim2_2 = torch::optim::SGD(
      model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
  auto optim3_2 = torch::optim::SGD(
      model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));

  auto x = torch::ones({10, 5});

  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };

  // Do 2 steps of model 1
  step(optim1, model1);
  step(optim1, model1);

  // Do 2 steps of model 2 without saving the optimizer
  step(optim2, model2);
  step(optim2_2, model2);

  // Do 2 steps of model 3, saving and reloading the optimizer in between
  step(optim3, model3);

  auto optim_tempfile = c10::make_tempfile();
  torch::save(optim3, optim_tempfile.name);
  torch::load(optim3_2, optim_tempfile.name);
  step(optim3_2, model3);

  param1 = model1->named_parameters();
  param2 = model2->named_parameters();
  param3 = model3->named_parameters();
  for (const auto& p : param1) {
    const auto& name = p.key();
    // Model 1 and 3 should be the same
    ASSERT_TRUE(
        param1[name].norm().item<float>() == param3[name].norm().item<float>());
    ASSERT_TRUE(
        param1[name].norm().item<float>() != param2[name].norm().item<float>());
  }
}

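// Each Optim_* test below first runs the generic round-trip test, then
// hand-writes a checkpoint in the legacy flat-buffer format and checks that it
// still loads, emitting the old-serialization warning.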
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // Collect optim1's sum buffers
  std::vector<torch::Tensor> sum_buffers;
  // Collect optim1's step buffers
  std::vector<int64_t> step_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto& param : params_) {
    auto key_ = param.unsafeGetTensorImpl();
    const AdagradParamState& curr_state_ =
        static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
    sum_buffers.emplace_back(curr_state_.sum());
    step_buffers.emplace_back(curr_state_.step());
  }
  // write sum_buffers and step_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Add an extra tensor for the lazy-init check, so that not all params have a
  // momentum buffer entry.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

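  // Collect the momentum buffers in the legacy flat format; the extra tensor
  // appended above never received gradients, so it has no state entry and is
  // skipped.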
  std::vector<at::Tensor> momentum_buffers;
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const SGDParamState& curr_state_ =
          static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
      momentum_buffers.emplace_back(curr_state_.momentum_buffer());
    }
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
  // write momentum_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_Adam) {
  test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
      AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));

  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Add an extra tensor for the lazy-init check, so that not all params have
  // entries in the buffers.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::Adam(
      model1_params, torch::optim::AdamOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

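  // Collect the legacy Adam buffers; max_exp_avg_sq is only defined when
  // amsgrad is enabled, so it is gathered conditionally.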
  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const AdamParamState& curr_state_ =
          static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_AdamW) {
  test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
      AdamWOptions().lr(0.99999).amsgrad(true).betas(
          std::make_tuple(0.999, 0.1)));

  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Add an extra tensor for the lazy-init check, so that not all params have
  // entries in the buffers.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::AdamW(
      model1_params, torch::optim::AdamWOptions().weight_decay(0.5));

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

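  // Same legacy-buffer collection as in the Adam test above, using the AdamW
  // state type.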
  std::vector<int64_t> step_buffers;
  std::vector<at::Tensor> exp_average_buffers;
  std::vector<at::Tensor> exp_average_sq_buffers;
  std::vector<at::Tensor> max_exp_average_sq_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const AdamWParamState& curr_state_ =
          static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
      step_buffers.emplace_back(curr_state_.step());
      exp_average_buffers.emplace_back(curr_state_.exp_avg());
      exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
      if (curr_state_.max_exp_avg_sq().defined()) {
        max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_buffers", exp_average_buffers);
  write_tensors_to_archive(
      output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
  write_tensors_to_archive(
      output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_RMSprop) {
  auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
  test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);

  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();

  // Add an extra tensor for the lazy-init check, so that not all params have a
  // momentum buffer entry.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::RMSprop(model1_params, options);

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    optimizer.step();
  };
  step(optim1, model1);

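  // Collect the legacy RMSprop buffers; the momentum and grad-average buffers
  // are only defined when momentum and centering are enabled, respectively.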
  std::vector<at::Tensor> square_average_buffers;
  std::vector<at::Tensor> momentum_buffers;
  std::vector<at::Tensor> grad_average_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto i : c10::irange(params_.size())) {
    if (i != (params_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      square_average_buffers.emplace_back(curr_state_.square_avg());
      if (curr_state_.momentum_buffer().defined()) {
        momentum_buffers.emplace_back(curr_state_.momentum_buffer());
      }
      if (curr_state_.grad_avg().defined()) {
        grad_average_buffers.emplace_back(curr_state_.grad_avg());
      }
    }
  }
  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "square_average_buffers", square_average_buffers);
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_tensors_to_archive(
      output_archive, "grad_average_buffers", grad_average_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);
  auto optim1_2 = RMSprop(model1_params, options);
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto& optim1_2_state = optim1_2.state();
  // The old RMSprop format didn't track the step value, so copy it over from
  // optim1's state before comparing.
  for (const auto i : c10::irange(params1_2_.size())) {
    if (i != (params1_2_.size() - 1)) {
      auto key_ = params_[i].unsafeGetTensorImpl();
      const RMSpropParamState& curr_state_ =
          static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
      RMSpropParamState& curr_state1_2_ =
          static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
      curr_state1_2_.step(curr_state_.step());
    }
  }
  is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, Optim_LBFGS) {
  test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
      LBFGSOptions(), true);
  // backward-compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // Add an extra tensor for the lazy-init check, so that not all params have
  // entries in the buffers.
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 =
      torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());

  auto x = torch::ones({10, 5});
  auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
    optimizer.zero_grad();
    auto y = model->forward(x).sum();
    y.backward();
    auto closure = []() { return torch::tensor({10}); };
    optimizer.step(closure);
  };

  step(optim1, model1);

  at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
  std::deque<at::Tensor> old_dirs, old_stps;

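  // Capture optim1's LBFGS state (keyed by the first parameter, since LBFGS
  // keeps a single global state) in the old flat-buffer layout.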
  const auto& params_ = optim1.param_groups()[0].params();
  auto key_ = params_[0].unsafeGetTensorImpl();
  const auto& optim1_state =
      static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
  d = optim1_state.d();
  t = at::tensor(optim1_state.t());
  H_diag = optim1_state.H_diag();
  prev_flat_grad = optim1_state.prev_flat_grad();
  prev_loss = at::tensor(optim1_state.prev_loss());
  old_dirs = optim1_state.old_dirs();
  // `old_stps` must be captured as well, since it is written to the archive
  // below alongside `old_dirs`.
  old_stps = optim1_state.old_stps();

  // write buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("d", d, /*is_buffer=*/true);
  output_archive.write("t", t, /*is_buffer=*/true);
  output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
  output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
  output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
  write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
  write_tensors_to_archive(output_archive, "old_stps", old_stps);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);

  const auto& params1_2_ = optim1_2.param_groups()[0].params();
  auto param_key = params1_2_[0].unsafeGetTensorImpl();
  auto& optim1_2_state =
      static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));

  // The old LBFGS format didn't track the func_evals, n_iter, ro, and al
  // values, so copy them over before comparing.
  optim1_2_state.func_evals(optim1_state.func_evals());
  optim1_2_state.n_iter(optim1_state.n_iter());
  optim1_2_state.ro(optim1_state.ro());
  optim1_2_state.al(optim1_state.al());

  is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
}

TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load an XOR model!
  auto getLoss = [](Sequential model,
                    uint32_t batch_size,
                    bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (const auto i : c10::irange(batch_size)) {
      inputs[i] = torch::randint(2, {2}, torch::kInt64);
      labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
    }
    auto x = model->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(x, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  auto model3 = xor_model();
  auto optimizer = torch::optim::SGD(
      model->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  float running_loss = 1;
  int epoch = 0;
  while (running_loss > 0.1) {
    torch::Tensor loss = getLoss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
    epoch++;
  }

  auto tempfile = c10::make_tempfile();
  torch::save(model, tempfile.name);
  torch::load(model2, tempfile.name);

  auto loss = getLoss(model2, 100);
  ASSERT_LT(loss.item<float>(), 0.1);

  model2->to(torch::kCUDA);
  loss = getLoss(model2, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);

  auto tempfile2 = c10::make_tempfile();
  torch::save(model2, tempfile2.name);
  torch::load(model3, tempfile2.name);

  loss = getLoss(model3, 100, true);
  ASSERT_LT(loss.item<float>(), 0.1);
}

TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  struct C : torch::nn::Module {
    C() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct B : torch::nn::Module {};
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("c", std::make_shared<C>());
    }
  };
  struct M : torch::nn::Module {
    M() {
      register_module("a", std::make_shared<A>());
    }
  };

  auto out = std::make_shared<M>();
  std::stringstream ss;
  torch::save(out, ss);
  auto in = std::make_shared<M>();
  torch::load(in, ss);

  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
  ASSERT_EQ(output, 5);
}

TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  for (const auto i : c10::irange(x_vec.size())) {
    auto& x = x_vec[i];
    auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}

TEST(SerializeTest, IValue) {
  c10::IValue ivalue(1);
  auto tempfile = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  output_archive.write("value", ivalue);
  output_archive.save_to(tempfile.name);

  torch::serialize::InputArchive input_archive;
  input_archive.load_from(tempfile.name);
  c10::IValue ivalue_out;
  input_archive.read("value", ivalue_out);
  ASSERT_EQ(ivalue_out.toInt(), 1);

  ASSERT_THROWS_WITH(
      input_archive.read("bad_key", ivalue_out),
      "does not have a field with name");
}

// NOTE: if a `Module` contains unserializable submodules (e.g.
// `nn::Functional`), we expect those submodules to be skipped when the
// `Module` is being serialized.
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  struct A : torch::nn::Module {
    A() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  std::stringstream ss;
  torch::save(out, ss);

  torch::serialize::InputArchive archive;
  archive.load_from(ss);
  torch::serialize::InputArchive relu_archive;

  // Submodule with name "relu" should not exist in the `InputArchive`,
  // because the "relu" submodule is an `nn::Functional` and is not
  // serializable.
  ASSERT_FALSE(archive.try_read("relu", relu_archive));
}

// NOTE: If a `Module` contains unserializable submodules (e.g.
// `nn::Functional`), we don't check the existence of those submodules in the
// `InputArchive` when deserializing.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  struct B : torch::nn::Module {
    B() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  // Manually change the values of "b.foo", so that we can check whether the
  // buffer contains these values after deserialization.
  out->named_buffers()["b.foo"].fill_(1);
  auto tempfile = c10::make_tempfile();
  torch::save(out, tempfile.name);

  torch::serialize::InputArchive archive;
  archive.load_from(tempfile.name);
  torch::serialize::InputArchive archive_b;
  torch::serialize::InputArchive archive_relu;
  torch::Tensor tensor_foo;

  ASSERT_TRUE(archive.try_read("b", archive_b));
  ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));

  // Submodule with name "relu1" should not exist in `archive_b`, because the
  // "relu1" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));

  // Submodule with name "relu2" should not exist in `archive`, because the
  // "relu2" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive.try_read("relu2", archive_relu));

  auto in = std::make_shared<A>();
  // `torch::load(...)` works without error, even though `A` contains the
  // `nn::Functional` submodules while the serialized file doesn't, because the
  // `nn::Functional` submodules are not serializable and thus ignored when
  // deserializing.
  torch::load(in, tempfile.name);

  // Check that the "b.foo" buffer is correctly deserialized from the file.
  const int output = in->named_buffers()["b.foo"].sum().item<int>();
  // `output` should equal the sum of the values we manually assigned to
  // "b.foo" before serialization.
  ASSERT_EQ(output, 5);
}