xref: /aosp_15_r20/external/pytorch/test/cpp/api/serialize.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/flat_hash_map.h>
4 #include <c10/util/irange.h>
5 #include <c10/util/tempfile.h>
6 
7 #include <torch/torch.h>
8 
9 #include <test/cpp/api/support.h>
10 
11 #include <cstdio>
12 #include <memory>
13 #include <sstream>
14 #include <string>
15 #include <vector>
16 
17 using namespace torch::test;
18 using namespace torch::nn;
19 using namespace torch::optim;
20 
21 namespace {
xor_model()22 Sequential xor_model() {
23   return Sequential(
24       Linear(2, 8),
25       Functional(at::sigmoid),
26       Linear(8, 1),
27       Functional(at::sigmoid));
28 }
29 
save_and_load(torch::Tensor input)30 torch::Tensor save_and_load(torch::Tensor input) {
31   std::stringstream stream;
32   torch::save(input, stream);
33   torch::Tensor tensor;
34   torch::load(tensor, stream);
35   return tensor;
36 }
37 } // namespace
38 
39 template <typename DerivedOptions>
is_optimizer_param_group_equal(const OptimizerParamGroup & lhs,const OptimizerParamGroup & rhs)40 void is_optimizer_param_group_equal(
41     const OptimizerParamGroup& lhs,
42     const OptimizerParamGroup& rhs) {
43   const auto& lhs_params = lhs.params();
44   const auto& rhs_params = rhs.params();
45 
46   ASSERT_TRUE(lhs_params.size() == rhs_params.size());
47   for (const auto j : c10::irange(lhs_params.size())) {
48     ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
49   }
50   ASSERT_TRUE(
51       static_cast<const DerivedOptions&>(lhs.options()) ==
52       static_cast<const DerivedOptions&>(rhs.options()));
53 }
54 
55 template <typename DerivedOptimizerParamState>
is_optimizer_state_equal(const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & lhs_state,const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & rhs_state)56 void is_optimizer_state_equal(
57     const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
58         lhs_state,
59     const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
60         rhs_state) {
61   ASSERT_TRUE(lhs_state.size() == rhs_state.size());
62   for (const auto& value : lhs_state) {
63     auto found = rhs_state.find(value.first);
64     ASSERT_TRUE(found != rhs_state.end());
65     const DerivedOptimizerParamState& lhs_curr_state =
66         static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
67     const DerivedOptimizerParamState& rhs_curr_state =
68         static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
69     ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
70   }
71 }
72 
73 template <
74     typename OptimizerClass,
75     typename DerivedOptimizerOptions,
76     typename DerivedOptimizerParamState>
test_serialize_optimizer(DerivedOptimizerOptions options,bool only_has_global_state=false)77 void test_serialize_optimizer(
78     DerivedOptimizerOptions options,
79     bool only_has_global_state = false) {
80   torch::manual_seed(0);
81   auto model1 = Linear(5, 2);
82   auto model2 = Linear(5, 2);
83   auto model3 = Linear(5, 2);
84 
85   // Models 1, 2, 3 will have the same parameters.
86   auto model_tempfile = c10::make_tempfile();
87   torch::save(model1, model_tempfile.name);
88   torch::load(model2, model_tempfile.name);
89   torch::load(model3, model_tempfile.name);
90 
91   auto param1 = model1->named_parameters();
92   auto param2 = model2->named_parameters();
93   auto param3 = model3->named_parameters();
94   for (const auto& p : param1) {
95     ASSERT_TRUE(p->allclose(param2[p.key()]));
96     ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
97   }
98   // Make some optimizers
99   auto optim1 = OptimizerClass(
100       {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
101   auto optim2 = OptimizerClass(model2->parameters(), options);
102   auto optim2_2 = OptimizerClass(model2->parameters(), options);
103   auto optim3 = OptimizerClass(model3->parameters(), options);
104   auto optim3_2 = OptimizerClass(model3->parameters(), options);
105   for (auto& param_group : optim3_2.param_groups()) {
106     const double lr = param_group.options().get_lr();
107     // change the learning rate, which will be overwritten by the loading
108     // otherwise, test cannot check if options are saved and loaded correctly
109     param_group.options().set_lr(lr + 0.01);
110   }
111 
112   auto x = torch::ones({10, 5});
113 
114   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
115     optimizer.zero_grad();
116     auto y = model->forward(x).sum();
117     y.backward();
118     auto closure = []() { return torch::tensor({10}); };
119     optimizer.step(closure);
120   };
121 
122   // Do 2 steps of model1
123   step(optim1, model1);
124   step(optim1, model1);
125 
126   // Do 2 steps of model 2 without saving the optimizer
127   step(optim2, model2);
128   step(optim2_2, model2);
129 
130   // Do 1 step of model 3
131   step(optim3, model3);
132 
133   // save the optimizer
134   auto optim_tempfile = c10::make_tempfile();
135   torch::save(optim3, optim_tempfile.name);
136   torch::load(optim3_2, optim_tempfile.name);
137 
138   auto& optim3_2_param_groups = optim3_2.param_groups();
139   auto& optim3_param_groups = optim3.param_groups();
140   auto& optim3_2_state = optim3_2.state();
141   auto& optim3_state = optim3.state();
142 
143   // optim3_2 and optim1 should have param_groups and state of size 1 and
144   // state_size respectively
145   ASSERT_TRUE(optim3_2_param_groups.size() == 1);
146   // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
147   // global state
148   unsigned state_size = only_has_global_state ? 1 : 2;
149   ASSERT_TRUE(optim3_2_state.size() == state_size);
150 
151   // optim3_2 and optim1 should have param_groups and state of same size
152   ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
153   ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());
154 
155   // checking correctness of serialization logic for optimizer.param_groups_ and
156   // optimizer.state_
157   for (const auto i : c10::irange(optim3_2_param_groups.size())) {
158     is_optimizer_param_group_equal<DerivedOptimizerOptions>(
159         optim3_2_param_groups[i], optim3_param_groups[i]);
160     is_optimizer_state_equal<DerivedOptimizerParamState>(
161         optim3_2_state, optim3_state);
162   }
163 
164   // Do step2 for model 3
165   step(optim3_2, model3);
166 
167   param1 = model1->named_parameters();
168   param2 = model2->named_parameters();
169   param3 = model3->named_parameters();
170   for (const auto& p : param1) {
171     const auto& name = p.key();
172     // Model 1 and 3 should be the same
173     ASSERT_TRUE(
174         param1[name].norm().item<float>() == param3[name].norm().item<float>());
175     ASSERT_TRUE(
176         param1[name].norm().item<float>() != param2[name].norm().item<float>());
177   }
178 }
179 
180 /// Utility function to save a value of `int64_t` type.
write_int_value(torch::serialize::OutputArchive & archive,const std::string & key,const int64_t & value)181 void write_int_value(
182     torch::serialize::OutputArchive& archive,
183     const std::string& key,
184     const int64_t& value) {
185   archive.write(key, c10::IValue(value));
186 }
187 // Utility function to save a vector of buffers.
188 template <typename BufferContainer>
write_tensors_to_archive(torch::serialize::OutputArchive & archive,const std::string & key,const BufferContainer & buffers)189 void write_tensors_to_archive(
190     torch::serialize::OutputArchive& archive,
191     const std::string& key,
192     const BufferContainer& buffers) {
193   archive.write(
194       key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
195   for (const auto index : c10::irange(buffers.size())) {
196     archive.write(
197         key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
198   }
199 }
200 
201 // Utility function to save a vector of step buffers.
write_step_buffers(torch::serialize::OutputArchive & archive,const std::string & key,const std::vector<int64_t> & steps)202 void write_step_buffers(
203     torch::serialize::OutputArchive& archive,
204     const std::string& key,
205     const std::vector<int64_t>& steps) {
206   std::vector<torch::Tensor> tensors;
207   tensors.reserve(steps.size());
208   for (const auto& step : steps) {
209     tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
210   }
211   write_tensors_to_archive(archive, key, tensors);
212 }
213 
// Calls `funcname(optimizer, filename)` while a WarningCapture is alive and
// asserts that the captured text contains "old serialization" exactly once.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  {                                                                          \
    WarningCapture captured_warnings_;                                       \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(                                            \
            captured_warnings_.str(), "old serialization"),                  \
        1);                                                                  \
  }
221 
// InputArchive::keys() must list every key that was written, in order.
TEST(SerializeTest, KeysFunc) {
  auto archive_file = c10::make_tempfile();

  // Write three integer entries keyed "element/0" .. "element/2".
  torch::serialize::OutputArchive writer;
  for (int64_t value = 0; value < 3; ++value) {
    writer.write("element/" + std::to_string(value), c10::IValue(value));
  }
  writer.save_to(archive_file.name);

  torch::serialize::InputArchive reader;
  reader.load_from(archive_file.name);

  const std::vector<std::string> keys = reader.keys();
  ASSERT_EQ(keys.size(), 3);
  for (size_t idx = 0; idx < keys.size(); ++idx) {
    ASSERT_EQ(keys[idx], "element/" + std::to_string(idx));
  }
}
238 
// InputArchive::try_read() must report failure for an unknown key and return
// the stored value for a known one.
TEST(SerializeTest, TryReadFunc) {
  auto archive_file = c10::make_tempfile();

  // Write three integer entries keyed "element/0" .. "element/2".
  torch::serialize::OutputArchive writer;
  for (int64_t value = 0; value < 3; ++value) {
    writer.write("element/" + std::to_string(value), c10::IValue(value));
  }
  writer.save_to(archive_file.name);

  torch::serialize::InputArchive reader;
  reader.load_from(archive_file.name);

  c10::IValue fetched;
  ASSERT_FALSE(reader.try_read("1", fetched));
  ASSERT_TRUE(reader.try_read("element/1", fetched));
  ASSERT_EQ(fetched.toInt(), 1);
}
254 
// A random tensor survives a save/load round trip through a stream.
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  const auto original = torch::randn({5, 5});
  const auto restored = save_and_load(original);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
265 
// Tensors produced by conj / _neg_view (and their combination) round-trip
// through save/load; a ZeroTensor is rejected.
TEST(SerializeTest, MathBits) {
  torch::manual_seed(0);

  const auto complex_options =
      torch::TensorOptions{}.dtype(torch::kComplexFloat);
  const auto base = torch::randn({5, 5}, complex_options);

  // Result of torch::conj.
  {
    const auto reference = torch::conj(base);
    const auto roundtripped = save_and_load(reference);

    ASSERT_TRUE(roundtripped.defined());
    ASSERT_EQ(roundtripped.sizes().vec(), reference.sizes().vec());
    ASSERT_TRUE(roundtripped.allclose(reference));
  }

  // Result of torch::_neg_view.
  {
    const auto reference = torch::_neg_view(base);
    const auto roundtripped = save_and_load(reference);

    ASSERT_TRUE(roundtripped.defined());
    ASSERT_EQ(roundtripped.sizes().vec(), reference.sizes().vec());
    ASSERT_TRUE(roundtripped.allclose(reference));
  }

  // torch::conj applied on top of torch::_neg_view.
  {
    const auto reference = torch::conj(torch::_neg_view(base));
    const auto roundtripped = save_and_load(reference);

    ASSERT_TRUE(roundtripped.defined());
    ASSERT_EQ(roundtripped.sizes().vec(), reference.sizes().vec());
    ASSERT_TRUE(roundtripped.allclose(reference));
  }

  {
    // We don't support serializing `ZeroTensor` as it is not public facing yet.
    // If in future, `ZeroTensor` serialization is supported, this test should
    // start failing!
    const auto zero_tensor = torch::_efficientzerotensor({5, 5});
    ASSERT_THROWS_WITH(
        save_and_load(zero_tensor), "ZeroTensor is not serializable,");
  }
}
306 
// A random tensor survives a save/load round trip through a temporary file.
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);

  const auto original = torch::randn({5, 5});

  auto tensor_file = c10::make_tempfile();
  torch::save(original, tensor_file.name);

  torch::Tensor restored;
  torch::load(restored, tensor_file.name);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
322 
// A tensor can be saved through a write callback and loaded back either from
// a raw (pointer, size) buffer or through read/size callbacks.
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  // Collect everything the writer callback receives into one byte string.
  std::string serialized;
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(static_cast<const char*>(buf), n);
    return n;
  });

  // Load from the raw buffer.
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  // Load through a (read, size) callback pair backed by the same bytes.
  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        // pos < size here, so the subtraction cannot underflow. Clamping `n`
        // to the bytes that remain avoids the `pos + n` sum, which could wrap
        // around for very large `n`.
        const size_t remaining =
            serialized.size() - static_cast<size_t>(pos);
        const size_t nbytes = std::min(n, remaining);
        memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
356 
// A tensor that was resized in place from {11, 5} to {5, 5} round-trips.
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  auto original = torch::randn({11, 5});
  original.resize_({5, 5});
  const auto restored = save_and_load(original);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
368 
// A slice along dimension 0 (rows 1..4 of an {11, 5} tensor) round-trips.
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  const auto full = torch::randn({11, 5});
  const auto sliced = full.slice(0, 1, 5);
  const auto restored = save_and_load(sliced);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(sliced.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(sliced.allclose(restored));
}
380 
// A slice along dimension 1 (columns 1..3 of an {11, 5} tensor) round-trips.
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  const auto full = torch::randn({11, 5});
  const auto sliced = full.slice(1, 1, 4);
  const auto restored = save_and_load(sliced);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(sliced.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(sliced.allclose(restored));
}
392 
// Loading into a module hierarchy whose names differ from the saved one must
// fail with an error that spells out the full dotted path of the missing key.
TEST(SerializeTest, ErrorOnMissingKey) {
  // Innermost module: owns one buffer registered under `buffer_name`.
  struct Leaf : torch::nn::Module {
    Leaf(const std::string& buffer_name) {
      register_buffer(buffer_name, torch::ones(5, torch::kFloat));
    }
  };
  // Middle module: owns a Leaf registered under `leaf_name`.
  struct Branch : torch::nn::Module {
    Branch(const std::string& leaf_name, const std::string& buffer_name) {
      register_module(leaf_name, std::make_shared<Leaf>(buffer_name));
    }
  };
  // Top-level module: owns a Branch registered under `branch_name`.
  struct Root : torch::nn::Module {
    Root(
        const std::string& branch_name,
        const std::string& leaf_name,
        const std::string& buffer_name) {
      register_module(
          branch_name, std::make_shared<Branch>(leaf_name, buffer_name));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto saved_model = std::make_shared<Root>("a", "b", "c");
  auto wrong_buffer_model = std::make_shared<Root>("a", "b", "x");
  auto wrong_submodule_model = std::make_shared<Root>("a", "x", "c");

  std::stringstream stream;
  torch::save(saved_model, stream);

  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(wrong_buffer_model, stream),
      "No such serialized tensor 'a.b.x'");

  // Rewind so the second load reads the archive from the start again.
  stream.seekg(0, std::ios::beg);
  ASSERT_THROWS_WITH(
      torch::load(wrong_submodule_model, stream),
      "No such serialized submodule: 'a.x'");
}
426 
// Trains a small network on XOR, saves it, loads it into a second network and
// checks that the loaded network still achieves a low loss.
TEST(SerializeTest, XOR) {
  // Draws `batch_size` random XOR samples and returns the binary
  // cross-entropy loss of `net` on them.
  auto compute_loss = [](Sequential net, uint32_t batch_size) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    for (const auto row : c10::irange(batch_size)) {
      inputs[row] = torch::randint(2, {2}, torch::kInt64);
      labels[row] =
          inputs[row][0].item<int64_t>() ^ inputs[row][1].item<int64_t>();
    }
    auto prediction = net->forward<torch::Tensor>(inputs);
    return torch::binary_cross_entropy(prediction, labels);
  };

  auto model = xor_model();
  auto model2 = xor_model();
  // NOTE(review): model3 is never used below. It is kept because building it
  // presumably draws from the global RNG, which the later random batches
  // depend on - confirm before removing.
  auto model3 = xor_model();

  const auto sgd_options =
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6);
  auto optimizer = torch::optim::SGD(model->parameters(), sgd_options);

  // Train until the running average of the loss drops to 0.1 or below; the
  // assertion fails the test if that takes more than 3000 epochs.
  float running_loss = 1;
  for (int epoch = 0; running_loss > 0.1; ++epoch) {
    torch::Tensor loss = compute_loss(model, 4);
    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
  }

  auto model_file = c10::make_tempfile();
  torch::save(model, model_file.name);
  torch::load(model2, model_file.name);

  const auto loaded_loss = compute_loss(model2, 100);
  ASSERT_LT(loaded_loss.item<float>(), 0.1);
}
469 
TEST(SerializeTest,Optim)470 TEST(SerializeTest, Optim) {
471   auto model1 = Linear(5, 2);
472   auto model2 = Linear(5, 2);
473   auto model3 = Linear(5, 2);
474 
475   // Models 1, 2, 3 will have the same parameters.
476   auto model_tempfile = c10::make_tempfile();
477   torch::save(model1, model_tempfile.name);
478   torch::load(model2, model_tempfile.name);
479   torch::load(model3, model_tempfile.name);
480 
481   auto param1 = model1->named_parameters();
482   auto param2 = model2->named_parameters();
483   auto param3 = model3->named_parameters();
484   for (const auto& p : param1) {
485     ASSERT_TRUE(p->allclose(param2[p.key()]));
486     ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
487   }
488 
489   // Make some optimizers with momentum (and thus state)
490   auto optim1 = torch::optim::SGD(
491       model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
492   auto optim2 = torch::optim::SGD(
493       model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
494   auto optim2_2 = torch::optim::SGD(
495       model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
496   auto optim3 = torch::optim::SGD(
497       model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
498   auto optim3_2 = torch::optim::SGD(
499       model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
500 
501   auto x = torch::ones({10, 5});
502 
503   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
504     optimizer.zero_grad();
505     auto y = model->forward(x).sum();
506     y.backward();
507     optimizer.step();
508   };
509 
510   // Do 2 steps of model1
511   step(optim1, model1);
512   step(optim1, model1);
513 
514   // Do 2 steps of model 2 without saving the optimizer
515   step(optim2, model2);
516   step(optim2_2, model2);
517 
518   // Do 2 steps of model 3 while saving the optimizer
519   step(optim3, model3);
520 
521   auto optim_tempfile = c10::make_tempfile();
522   torch::save(optim3, optim_tempfile.name);
523   torch::load(optim3_2, optim_tempfile.name);
524   step(optim3_2, model3);
525 
526   param1 = model1->named_parameters();
527   param2 = model2->named_parameters();
528   param3 = model3->named_parameters();
529   for (const auto& p : param1) {
530     const auto& name = p.key();
531     // Model 1 and 3 should be the same
532     ASSERT_TRUE(
533         param1[name].norm().item<float>() == param3[name].norm().item<float>());
534     ASSERT_TRUE(
535         param1[name].norm().item<float>() != param2[name].norm().item<float>());
536   }
537 }
538 
// Adagrad: round-trip check, then a check that an archive written in the old
// buffer-based layout loads (with a warning) into equivalent state.
TEST(SerializeTest, Optim_Adagrad) {
  test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
      AdagradOptions(1e-1));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto optim1 = torch::optim::Adagrad(
      model1->parameters(), torch::optim::AdagradOptions(1e-1));

  auto x = torch::ones({10, 5});
  auto take_step = [&x](torch::optim::Optimizer& opt, Linear net) {
    opt.zero_grad();
    net->forward(x).sum().backward();
    opt.step();
  };
  take_step(optim1, model1);
  auto optim1_2 =
      Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));

  // Collect optim1's per-parameter `sum` tensors and `step` counts.
  std::vector<torch::Tensor> sum_buffers;
  std::vector<int64_t> step_buffers;
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  for (const auto& param : params_) {
    auto* impl = param.unsafeGetTensorImpl();
    const auto& param_state =
        static_cast<const AdagradParamState&>(*optim1_state.at(impl));
    sum_buffers.emplace_back(param_state.sum());
    step_buffers.emplace_back(param_state.step());
  }

  // write sum_buffers and step_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
  write_step_buffers(output_archive, "step_buffers", step_buffers);
  output_archive.save_to(optim_tempfile_old_format.name);

  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
}
582 
// SGD: round-trip check, then a check that an archive written in the old
// buffer-based layout loads (with a warning) into equivalent state.
TEST(SerializeTest, Optim_SGD) {
  test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
      SGDOptions(1e-1).momentum(0.9));

  // bc compatibility check
  auto model1 = Linear(5, 2);
  auto model1_params = model1->parameters();
  // added a tensor for lazy init check - when all params do not have a momentum
  // buffer entry
  model1_params.emplace_back(torch::randn({2, 3}));
  auto optim1 = torch::optim::SGD(
      model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));

  auto x = torch::ones({10, 5});
  auto take_step = [&x](torch::optim::Optimizer& opt, Linear net) {
    opt.zero_grad();
    net->forward(x).sum().backward();
    opt.step();
  };
  take_step(optim1, model1);

  std::vector<at::Tensor> momentum_buffers;
  int64_t iteration_{0};
  const auto& params_ = optim1.param_groups()[0].params();
  const auto& optim1_state = optim1.state();
  const auto last_index = params_.size() - 1;
  for (const auto i : c10::irange(params_.size())) {
    // The last parameter is the extra tensor appended above; it is skipped.
    if (i == last_index) {
      continue;
    }
    auto* impl = params_[i].unsafeGetTensorImpl();
    const auto& param_state =
        static_cast<const SGDParamState&>(*optim1_state.at(impl));
    momentum_buffers.emplace_back(param_state.momentum_buffer());
  }
  ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));

  // write momentum_buffers to the file
  auto optim_tempfile_old_format = c10::make_tempfile();
  torch::serialize::OutputArchive output_archive;
  write_tensors_to_archive(
      output_archive, "momentum_buffers", momentum_buffers);
  write_int_value(output_archive, "iteration_", iteration_);
  output_archive.save_to(optim_tempfile_old_format.name);

  auto optim1_2 =
      SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
  OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
      torch::load, optim1_2, optim_tempfile_old_format.name);
  is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
}
631 
TEST(SerializeTest,Optim_Adam)632 TEST(SerializeTest, Optim_Adam) {
633   test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
634       AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));
635 
636   // bc compatibility check
637   auto model1 = Linear(5, 2);
638   auto model1_params = model1->parameters();
639   // added a tensor for lazy init check - when all params do not have entry in
640   // buffers
641   model1_params.emplace_back(torch::randn({2, 3}));
642   auto optim1 = torch::optim::Adam(
643       model1_params, torch::optim::AdamOptions().weight_decay(0.5));
644 
645   auto x = torch::ones({10, 5});
646   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
647     optimizer.zero_grad();
648     auto y = model->forward(x).sum();
649     y.backward();
650     optimizer.step();
651   };
652   step(optim1, model1);
653 
654   std::vector<int64_t> step_buffers;
655   std::vector<at::Tensor> exp_average_buffers;
656   std::vector<at::Tensor> exp_average_sq_buffers;
657   std::vector<at::Tensor> max_exp_average_sq_buffers;
658   const auto& params_ = optim1.param_groups()[0].params();
659   const auto& optim1_state = optim1.state();
660   for (const auto i : c10::irange(params_.size())) {
661     if (i != (params_.size() - 1)) {
662       auto key_ = params_[i].unsafeGetTensorImpl();
663       const AdamParamState& curr_state_ =
664           static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
665       step_buffers.emplace_back(curr_state_.step());
666       exp_average_buffers.emplace_back(curr_state_.exp_avg());
667       exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
668       if (curr_state_.max_exp_avg_sq().defined()) {
669         max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
670       }
671     }
672   }
673   // write buffers to the file
674   auto optim_tempfile_old_format = c10::make_tempfile();
675   torch::serialize::OutputArchive output_archive;
676   write_step_buffers(output_archive, "step_buffers", step_buffers);
677   write_tensors_to_archive(
678       output_archive, "exp_average_buffers", exp_average_buffers);
679   write_tensors_to_archive(
680       output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
681   write_tensors_to_archive(
682       output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
683   output_archive.save_to(optim_tempfile_old_format.name);
684   auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
685   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
686       torch::load, optim1_2, optim_tempfile_old_format.name);
687   is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
688 }
689 
TEST(SerializeTest,Optim_AdamW)690 TEST(SerializeTest, Optim_AdamW) {
691   test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
692       AdamWOptions().lr(0.99999).amsgrad(true).betas(
693           std::make_tuple(0.999, 0.1)));
694 
695   // bc compatibility check
696   auto model1 = Linear(5, 2);
697   auto model1_params = model1->parameters();
698   // added a tensor for lazy init check - when all params do not have entry in
699   // buffers
700   model1_params.emplace_back(torch::randn({2, 3}));
701   auto optim1 = torch::optim::AdamW(
702       model1_params, torch::optim::AdamWOptions().weight_decay(0.5));
703 
704   auto x = torch::ones({10, 5});
705   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
706     optimizer.zero_grad();
707     auto y = model->forward(x).sum();
708     y.backward();
709     optimizer.step();
710   };
711   step(optim1, model1);
712 
713   std::vector<int64_t> step_buffers;
714   std::vector<at::Tensor> exp_average_buffers;
715   std::vector<at::Tensor> exp_average_sq_buffers;
716   std::vector<at::Tensor> max_exp_average_sq_buffers;
717   const auto& params_ = optim1.param_groups()[0].params();
718   const auto& optim1_state = optim1.state();
719   for (const auto i : c10::irange(params_.size())) {
720     if (i != (params_.size() - 1)) {
721       auto key_ = params_[i].unsafeGetTensorImpl();
722       const AdamWParamState& curr_state_ =
723           static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
724       step_buffers.emplace_back(curr_state_.step());
725       exp_average_buffers.emplace_back(curr_state_.exp_avg());
726       exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
727       if (curr_state_.max_exp_avg_sq().defined()) {
728         max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
729       }
730     }
731   }
732   // write buffers to the file
733   auto optim_tempfile_old_format = c10::make_tempfile();
734   torch::serialize::OutputArchive output_archive;
735   write_step_buffers(output_archive, "step_buffers", step_buffers);
736   write_tensors_to_archive(
737       output_archive, "exp_average_buffers", exp_average_buffers);
738   write_tensors_to_archive(
739       output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
740   write_tensors_to_archive(
741       output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
742   output_archive.save_to(optim_tempfile_old_format.name);
743   auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
744   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
745       torch::load, optim1_2, optim_tempfile_old_format.name);
746   is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
747 }
748 
TEST(SerializeTest,Optim_RMSprop)749 TEST(SerializeTest, Optim_RMSprop) {
750   auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
751   test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);
752 
753   // bc compatibility check
754   auto model1 = Linear(5, 2);
755   auto model1_params = model1->parameters();
756 
757   // added a tensor for lazy init check - when all params do not have a momentum
758   // buffer entry
759   model1_params.emplace_back(torch::randn({2, 3}));
760   auto optim1 = torch::optim::RMSprop(model1_params, options);
761 
762   auto x = torch::ones({10, 5});
763   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
764     optimizer.zero_grad();
765     auto y = model->forward(x).sum();
766     y.backward();
767     optimizer.step();
768   };
769   step(optim1, model1);
770 
771   std::vector<at::Tensor> square_average_buffers;
772   std::vector<at::Tensor> momentum_buffers;
773   std::vector<at::Tensor> grad_average_buffers;
774   const auto& params_ = optim1.param_groups()[0].params();
775   const auto& optim1_state = optim1.state();
776   for (const auto i : c10::irange(params_.size())) {
777     if (i != (params_.size() - 1)) {
778       auto key_ = params_[i].unsafeGetTensorImpl();
779       const RMSpropParamState& curr_state_ =
780           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
781       square_average_buffers.emplace_back(curr_state_.square_avg());
782       if (curr_state_.momentum_buffer().defined()) {
783         momentum_buffers.emplace_back(curr_state_.momentum_buffer());
784       }
785       if (curr_state_.grad_avg().defined()) {
786         grad_average_buffers.emplace_back(curr_state_.grad_avg());
787       }
788     }
789   }
790   // write buffers to the file
791   auto optim_tempfile_old_format = c10::make_tempfile();
792   torch::serialize::OutputArchive output_archive;
793   write_tensors_to_archive(
794       output_archive, "square_average_buffers", square_average_buffers);
795   write_tensors_to_archive(
796       output_archive, "momentum_buffers", momentum_buffers);
797   write_tensors_to_archive(
798       output_archive, "grad_average_buffers", grad_average_buffers);
799   output_archive.save_to(optim_tempfile_old_format.name);
800   auto optim1_2 = RMSprop(model1_params, options);
801   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
802       torch::load, optim1_2, optim_tempfile_old_format.name);
803   const auto& params1_2_ = optim1_2.param_groups()[0].params();
804   auto& optim1_2_state = optim1_2.state();
805   // old RMSprop didn't track step value
806   for (const auto i : c10::irange(params1_2_.size())) {
807     if (i != (params1_2_.size() - 1)) {
808       auto key_ = params_[i].unsafeGetTensorImpl();
809       const RMSpropParamState& curr_state_ =
810           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
811       RMSpropParamState& curr_state1_2_ =
812           static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
813       curr_state1_2_.step(curr_state_.step());
814     }
815   }
816   is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
817 }
818 
TEST(SerializeTest,Optim_LBFGS)819 TEST(SerializeTest, Optim_LBFGS) {
820   test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
821       LBFGSOptions(), true);
822   // bc compatibility check
823   auto model1 = Linear(5, 2);
824   auto model1_params = model1->parameters();
825   // added a tensor for lazy init check - when all params do not have entry in
826   // buffers
827   model1_params.emplace_back(torch::randn({2, 3}));
828   auto optim1 =
829       torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());
830 
831   auto x = torch::ones({10, 5});
832   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
833     optimizer.zero_grad();
834     auto y = model->forward(x).sum();
835     y.backward();
836     auto closure = []() { return torch::tensor({10}); };
837     optimizer.step(closure);
838   };
839 
840   step(optim1, model1);
841 
842   at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
843   std::deque<at::Tensor> old_dirs, old_stps;
844 
845   const auto& params_ = optim1.param_groups()[0].params();
846   auto key_ = params_[0].unsafeGetTensorImpl();
847   const auto& optim1_state =
848       static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
849   d = optim1_state.d();
850   t = at::tensor(optim1_state.t());
851   H_diag = optim1_state.H_diag();
852   prev_flat_grad = optim1_state.prev_flat_grad();
853   prev_loss = at::tensor(optim1_state.prev_loss());
854   old_dirs = optim1_state.old_dirs();
855 
856   // write buffers to the file
857   auto optim_tempfile_old_format = c10::make_tempfile();
858   torch::serialize::OutputArchive output_archive;
859   output_archive.write("d", d, /*is_buffer=*/true);
860   output_archive.write("t", t, /*is_buffer=*/true);
861   output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
862   output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
863   output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
864   write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
865   write_tensors_to_archive(output_archive, "old_stps", old_stps);
866   output_archive.save_to(optim_tempfile_old_format.name);
867 
868   auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
869   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
870       torch::load, optim1_2, optim_tempfile_old_format.name);
871 
872   const auto& params1_2_ = optim1_2.param_groups()[0].params();
873   auto param_key = params1_2_[0].unsafeGetTensorImpl();
874   auto& optim1_2_state =
875       static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
876 
877   // old LBFGS didn't track func_evals, n_iter, ro, al values
878   optim1_2_state.func_evals(optim1_state.func_evals());
879   optim1_2_state.n_iter(optim1_state.n_iter());
880   optim1_2_state.ro(optim1_state.ro());
881   optim1_2_state.al(optim1_state.al());
882 
883   is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
884 }
885 
// Trains the XOR model on the CPU until its smoothed loss falls below 0.1 and
// then checks that the trained weights survive two save/load round trips: one
// on the CPU and one after the restored model has been moved to CUDA.
TEST(SerializeTest, XOR_CUDA) {
  torch::manual_seed(0);
  // We better be able to save and load a XOR model!
  // Draws a random batch of XOR samples (optionally placed on the GPU) and
  // returns the binary cross entropy of `net` on that batch.
  auto batch_loss = [](Sequential net,
                       uint32_t batch_size,
                       bool is_cuda = false) {
    auto inputs = torch::empty({batch_size, 2});
    auto labels = torch::empty({batch_size});
    if (is_cuda) {
      inputs = inputs.cuda();
      labels = labels.cuda();
    }
    for (const auto row : c10::irange(batch_size)) {
      inputs[row] = torch::randint(2, {2}, torch::kInt64);
      const auto first_bit = inputs[row][0].item<int64_t>();
      const auto second_bit = inputs[row][1].item<int64_t>();
      labels[row] = first_bit ^ second_bit;
    }
    return torch::binary_cross_entropy(
        net->forward<torch::Tensor>(inputs), labels);
  };

  auto trained = xor_model();
  auto cpu_restored = xor_model();
  auto cuda_restored = xor_model();
  auto sgd = torch::optim::SGD(
      trained->parameters(),
      torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
          1e-6));

  // Exponentially smoothed loss; training stops once it is at most 0.1. The
  // assertion inside the loop aborts the test if that takes too many epochs.
  float running_loss = 1;
  for (int epoch = 0; running_loss > 0.1; ++epoch) {
    torch::Tensor step_loss = batch_loss(trained, 4);
    sgd.zero_grad();
    step_loss.backward();
    sgd.step();

    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    running_loss = running_loss * 0.99 + step_loss.sum().item<float>() * 0.01;
    ASSERT_LT(epoch, 3000);
  }

  // CPU round trip: trained -> file -> cpu_restored.
  auto cpu_file = c10::make_tempfile();
  torch::save(trained, cpu_file.name);
  torch::load(cpu_restored, cpu_file.name);

  auto eval_loss = batch_loss(cpu_restored, 100);
  ASSERT_LT(eval_loss.item<float>(), 0.1);

  // The restored model must still perform after being moved to the GPU.
  cpu_restored->to(torch::kCUDA);
  eval_loss = batch_loss(cpu_restored, 100, true);
  ASSERT_LT(eval_loss.item<float>(), 0.1);

  // CUDA round trip: cpu_restored (now on CUDA) -> file -> cuda_restored.
  auto cuda_file = c10::make_tempfile();
  torch::save(cpu_restored, cuda_file.name);
  torch::load(cuda_restored, cuda_file.name);

  eval_loss = batch_loss(cuda_restored, 100, true);
  ASSERT_LT(eval_loss.item<float>(), 0.1);
}
946 
// Round-trips a module tree in which only the deepest module owns a buffer and
// one sibling module owns nothing at all, then checks that the nested buffer
// can be read back from the loaded tree.
TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  // Leaf module: the only one holding data (five int32 ones under "foo").
  struct Leaf : torch::nn::Module {
    Leaf() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  // Module with neither parameters, buffers nor children.
  struct Hollow : torch::nn::Module {};
  // Intermediate module: registers children only ("b" and "c").
  struct Middle : torch::nn::Module {
    Middle() {
      register_module("b", std::make_shared<Hollow>());
      register_module("c", std::make_shared<Leaf>());
    }
  };
  // Root module: registers the intermediate module as "a".
  struct Root : torch::nn::Module {
    Root() {
      register_module("a", std::make_shared<Middle>());
    }
  };

  std::stringstream buffer;
  auto saved = std::make_shared<Root>();
  torch::save(saved, buffer);

  auto restored = std::make_shared<Root>();
  torch::load(restored, buffer);

  // Five ones were stored, so the restored buffer must sum to 5.
  const auto restored_buffers = restored->named_buffers();
  const int total = restored_buffers["a.c.foo"].sum().item<int>();
  ASSERT_EQ(total, 5);
}
977 
// Saves a vector of tensors to an in-memory stream, loads it back into a fresh
// vector and checks that every tensor comes back defined, with the same sizes
// and with values close to the original.
TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  // Guard the element-wise comparison: without this check a short `y_vec`
  // would be indexed out of bounds below instead of failing the test cleanly.
  ASSERT_EQ(x_vec.size(), y_vec.size());

  for (const auto i : c10::irange(x_vec.size())) {
    auto& x = x_vec[i];
    auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}
998 
// Writes a single integer IValue under the key "value", reads it back from the
// saved file, and checks that reading a key that was never written throws.
TEST(SerializeTest, IValue) {
  c10::IValue stored(1);
  auto archive_file = c10::make_tempfile();

  // Save side.
  torch::serialize::OutputArchive writer;
  writer.write("value", stored);
  writer.save_to(archive_file.name);

  // Load side.
  torch::serialize::InputArchive reader;
  reader.load_from(archive_file.name);
  c10::IValue loaded;
  reader.read("value", loaded);
  ASSERT_EQ(loaded.toInt(), 1);

  // A missing key is reported through an exception carrying this message.
  ASSERT_THROWS_WITH(
      reader.read("bad_key", loaded), "does not have a field with name");
}
1016 
1017 // NOTE: if a `Module` contains unserializable submodules (e.g.
1018 // `nn::Functional`), we expect those submodules to be skipped when the `Module`
1019 // is being serialized.
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  // Module whose only child is an `nn::Functional` registered as "relu".
  struct FunctionalOnly : torch::nn::Module {
    FunctionalOnly() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  std::stringstream buffer;
  auto saved = std::make_shared<FunctionalOnly>();
  torch::save(saved, buffer);

  torch::serialize::InputArchive root_archive;
  root_archive.load_from(buffer);

  // The saved archive must contain no entry named "relu": that submodule is
  // an `nn::Functional`, which is not serializable, so the lookup has to fail.
  torch::serialize::InputArchive relu_entry;
  const bool relu_found = root_archive.try_read("relu", relu_entry);
  ASSERT_FALSE(relu_found);
}
1040 
1041 // NOTE: If a `Module` contains unserializable submodules (e.g.
1042 // `nn::Functional`), we don't check the existence of those submodules in the
1043 // `InputArchive` when deserializing.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  // Inner module: an `nn::Functional` child ("relu1") plus an int32 buffer
  // ("foo") that starts out as five zeros.
  struct Inner : torch::nn::Module {
    Inner() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  // Outer module: the inner module ("b") next to another `nn::Functional`
  // child ("relu2").
  struct Outer : torch::nn::Module {
    Outer() {
      register_module("b", std::make_shared<Inner>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto saved = std::make_shared<Outer>();
  // Overwrite "b.foo" with ones so that, after loading, its contents can be
  // told apart from the zeros a freshly constructed module starts with.
  saved->named_buffers()["b.foo"].fill_(1);
  auto model_file = c10::make_tempfile();
  torch::save(saved, model_file.name);

  // Inspect the raw archive first.
  torch::serialize::InputArchive root_archive;
  root_archive.load_from(model_file.name);
  torch::serialize::InputArchive inner_archive;
  torch::serialize::InputArchive functional_archive;
  torch::Tensor foo_buffer;

  // The serializable parts must be present ...
  ASSERT_TRUE(root_archive.try_read("b", inner_archive));
  ASSERT_TRUE(inner_archive.try_read("foo", foo_buffer, /*is_buffer=*/true));

  // ... while "relu1" must be absent from the inner archive, because it is an
  // `nn::Functional` and therefore not serializable,
  ASSERT_FALSE(inner_archive.try_read("relu1", functional_archive));

  // and "relu2" must be absent from the root archive for the same reason.
  ASSERT_FALSE(root_archive.try_read("relu2", functional_archive));

  auto restored = std::make_shared<Outer>();
  // Loading succeeds although `Outer` declares `nn::Functional` submodules
  // that the file does not contain: such submodules are not serializable and
  // are consequently skipped during deserialization.
  torch::load(restored, model_file.name);

  // "b.foo" was filled with five ones before saving, so the restored buffer
  // must sum to 5.
  const auto restored_buffers = restored->named_buffers();
  const int total = restored_buffers["b.foo"].sum().item<int>();
  ASSERT_EQ(total, 5);
}
1095