xref: /aosp_15_r20/external/pytorch/test/cpp/api/serialize.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/util/tempfile.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
18*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
19*da0073e9SAndroid Build Coastguard Worker using namespace torch::optim;
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker namespace {
xor_model()22*da0073e9SAndroid Build Coastguard Worker Sequential xor_model() {
23*da0073e9SAndroid Build Coastguard Worker   return Sequential(
24*da0073e9SAndroid Build Coastguard Worker       Linear(2, 8),
25*da0073e9SAndroid Build Coastguard Worker       Functional(at::sigmoid),
26*da0073e9SAndroid Build Coastguard Worker       Linear(8, 1),
27*da0073e9SAndroid Build Coastguard Worker       Functional(at::sigmoid));
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker 
save_and_load(torch::Tensor input)30*da0073e9SAndroid Build Coastguard Worker torch::Tensor save_and_load(torch::Tensor input) {
31*da0073e9SAndroid Build Coastguard Worker   std::stringstream stream;
32*da0073e9SAndroid Build Coastguard Worker   torch::save(input, stream);
33*da0073e9SAndroid Build Coastguard Worker   torch::Tensor tensor;
34*da0073e9SAndroid Build Coastguard Worker   torch::load(tensor, stream);
35*da0073e9SAndroid Build Coastguard Worker   return tensor;
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker } // namespace
38*da0073e9SAndroid Build Coastguard Worker 
39*da0073e9SAndroid Build Coastguard Worker template <typename DerivedOptions>
is_optimizer_param_group_equal(const OptimizerParamGroup & lhs,const OptimizerParamGroup & rhs)40*da0073e9SAndroid Build Coastguard Worker void is_optimizer_param_group_equal(
41*da0073e9SAndroid Build Coastguard Worker     const OptimizerParamGroup& lhs,
42*da0073e9SAndroid Build Coastguard Worker     const OptimizerParamGroup& rhs) {
43*da0073e9SAndroid Build Coastguard Worker   const auto& lhs_params = lhs.params();
44*da0073e9SAndroid Build Coastguard Worker   const auto& rhs_params = rhs.params();
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(lhs_params.size() == rhs_params.size());
47*da0073e9SAndroid Build Coastguard Worker   for (const auto j : c10::irange(lhs_params.size())) {
48*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j]));
49*da0073e9SAndroid Build Coastguard Worker   }
50*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(
51*da0073e9SAndroid Build Coastguard Worker       static_cast<const DerivedOptions&>(lhs.options()) ==
52*da0073e9SAndroid Build Coastguard Worker       static_cast<const DerivedOptions&>(rhs.options()));
53*da0073e9SAndroid Build Coastguard Worker }
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker template <typename DerivedOptimizerParamState>
is_optimizer_state_equal(const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & lhs_state,const ska::flat_hash_map<void *,std::unique_ptr<OptimizerParamState>> & rhs_state)56*da0073e9SAndroid Build Coastguard Worker void is_optimizer_state_equal(
57*da0073e9SAndroid Build Coastguard Worker     const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
58*da0073e9SAndroid Build Coastguard Worker         lhs_state,
59*da0073e9SAndroid Build Coastguard Worker     const ska::flat_hash_map<void*, std::unique_ptr<OptimizerParamState>>&
60*da0073e9SAndroid Build Coastguard Worker         rhs_state) {
61*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(lhs_state.size() == rhs_state.size());
62*da0073e9SAndroid Build Coastguard Worker   for (const auto& value : lhs_state) {
63*da0073e9SAndroid Build Coastguard Worker     auto found = rhs_state.find(value.first);
64*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(found != rhs_state.end());
65*da0073e9SAndroid Build Coastguard Worker     const DerivedOptimizerParamState& lhs_curr_state =
66*da0073e9SAndroid Build Coastguard Worker         static_cast<const DerivedOptimizerParamState&>(*(value.second.get()));
67*da0073e9SAndroid Build Coastguard Worker     const DerivedOptimizerParamState& rhs_curr_state =
68*da0073e9SAndroid Build Coastguard Worker         static_cast<const DerivedOptimizerParamState&>(*(found->second.get()));
69*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(lhs_curr_state == rhs_curr_state);
70*da0073e9SAndroid Build Coastguard Worker   }
71*da0073e9SAndroid Build Coastguard Worker }
72*da0073e9SAndroid Build Coastguard Worker 
73*da0073e9SAndroid Build Coastguard Worker template <
74*da0073e9SAndroid Build Coastguard Worker     typename OptimizerClass,
75*da0073e9SAndroid Build Coastguard Worker     typename DerivedOptimizerOptions,
76*da0073e9SAndroid Build Coastguard Worker     typename DerivedOptimizerParamState>
test_serialize_optimizer(DerivedOptimizerOptions options,bool only_has_global_state=false)77*da0073e9SAndroid Build Coastguard Worker void test_serialize_optimizer(
78*da0073e9SAndroid Build Coastguard Worker     DerivedOptimizerOptions options,
79*da0073e9SAndroid Build Coastguard Worker     bool only_has_global_state = false) {
80*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
81*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
82*da0073e9SAndroid Build Coastguard Worker   auto model2 = Linear(5, 2);
83*da0073e9SAndroid Build Coastguard Worker   auto model3 = Linear(5, 2);
84*da0073e9SAndroid Build Coastguard Worker 
85*da0073e9SAndroid Build Coastguard Worker   // Models 1, 2, 3 will have the same parameters.
86*da0073e9SAndroid Build Coastguard Worker   auto model_tempfile = c10::make_tempfile();
87*da0073e9SAndroid Build Coastguard Worker   torch::save(model1, model_tempfile.name);
88*da0073e9SAndroid Build Coastguard Worker   torch::load(model2, model_tempfile.name);
89*da0073e9SAndroid Build Coastguard Worker   torch::load(model3, model_tempfile.name);
90*da0073e9SAndroid Build Coastguard Worker 
91*da0073e9SAndroid Build Coastguard Worker   auto param1 = model1->named_parameters();
92*da0073e9SAndroid Build Coastguard Worker   auto param2 = model2->named_parameters();
93*da0073e9SAndroid Build Coastguard Worker   auto param3 = model3->named_parameters();
94*da0073e9SAndroid Build Coastguard Worker   for (const auto& p : param1) {
95*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(p->allclose(param2[p.key()]));
96*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
97*da0073e9SAndroid Build Coastguard Worker   }
98*da0073e9SAndroid Build Coastguard Worker   // Make some optimizers
99*da0073e9SAndroid Build Coastguard Worker   auto optim1 = OptimizerClass(
100*da0073e9SAndroid Build Coastguard Worker       {torch::optim::OptimizerParamGroup(model1->parameters())}, options);
101*da0073e9SAndroid Build Coastguard Worker   auto optim2 = OptimizerClass(model2->parameters(), options);
102*da0073e9SAndroid Build Coastguard Worker   auto optim2_2 = OptimizerClass(model2->parameters(), options);
103*da0073e9SAndroid Build Coastguard Worker   auto optim3 = OptimizerClass(model3->parameters(), options);
104*da0073e9SAndroid Build Coastguard Worker   auto optim3_2 = OptimizerClass(model3->parameters(), options);
105*da0073e9SAndroid Build Coastguard Worker   for (auto& param_group : optim3_2.param_groups()) {
106*da0073e9SAndroid Build Coastguard Worker     const double lr = param_group.options().get_lr();
107*da0073e9SAndroid Build Coastguard Worker     // change the learning rate, which will be overwritten by the loading
108*da0073e9SAndroid Build Coastguard Worker     // otherwise, test cannot check if options are saved and loaded correctly
109*da0073e9SAndroid Build Coastguard Worker     param_group.options().set_lr(lr + 0.01);
110*da0073e9SAndroid Build Coastguard Worker   }
111*da0073e9SAndroid Build Coastguard Worker 
112*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
113*da0073e9SAndroid Build Coastguard Worker 
114*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
115*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
116*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
117*da0073e9SAndroid Build Coastguard Worker     y.backward();
118*da0073e9SAndroid Build Coastguard Worker     auto closure = []() { return torch::tensor({10}); };
119*da0073e9SAndroid Build Coastguard Worker     optimizer.step(closure);
120*da0073e9SAndroid Build Coastguard Worker   };
121*da0073e9SAndroid Build Coastguard Worker 
122*da0073e9SAndroid Build Coastguard Worker   // Do 2 steps of model1
123*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
124*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
125*da0073e9SAndroid Build Coastguard Worker 
126*da0073e9SAndroid Build Coastguard Worker   // Do 2 steps of model 2 without saving the optimizer
127*da0073e9SAndroid Build Coastguard Worker   step(optim2, model2);
128*da0073e9SAndroid Build Coastguard Worker   step(optim2_2, model2);
129*da0073e9SAndroid Build Coastguard Worker 
130*da0073e9SAndroid Build Coastguard Worker   // Do 1 step of model 3
131*da0073e9SAndroid Build Coastguard Worker   step(optim3, model3);
132*da0073e9SAndroid Build Coastguard Worker 
133*da0073e9SAndroid Build Coastguard Worker   // save the optimizer
134*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile = c10::make_tempfile();
135*da0073e9SAndroid Build Coastguard Worker   torch::save(optim3, optim_tempfile.name);
136*da0073e9SAndroid Build Coastguard Worker   torch::load(optim3_2, optim_tempfile.name);
137*da0073e9SAndroid Build Coastguard Worker 
138*da0073e9SAndroid Build Coastguard Worker   auto& optim3_2_param_groups = optim3_2.param_groups();
139*da0073e9SAndroid Build Coastguard Worker   auto& optim3_param_groups = optim3.param_groups();
140*da0073e9SAndroid Build Coastguard Worker   auto& optim3_2_state = optim3_2.state();
141*da0073e9SAndroid Build Coastguard Worker   auto& optim3_state = optim3.state();
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker   // optim3_2 and optim1 should have param_groups and state of size 1 and
144*da0073e9SAndroid Build Coastguard Worker   // state_size respectively
145*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(optim3_2_param_groups.size() == 1);
146*da0073e9SAndroid Build Coastguard Worker   // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one
147*da0073e9SAndroid Build Coastguard Worker   // global state
148*da0073e9SAndroid Build Coastguard Worker   unsigned state_size = only_has_global_state ? 1 : 2;
149*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(optim3_2_state.size() == state_size);
150*da0073e9SAndroid Build Coastguard Worker 
151*da0073e9SAndroid Build Coastguard Worker   // optim3_2 and optim1 should have param_groups and state of same size
152*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size());
153*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(optim3_2_state.size() == optim3_state.size());
154*da0073e9SAndroid Build Coastguard Worker 
155*da0073e9SAndroid Build Coastguard Worker   // checking correctness of serialization logic for optimizer.param_groups_ and
156*da0073e9SAndroid Build Coastguard Worker   // optimizer.state_
157*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(optim3_2_param_groups.size())) {
158*da0073e9SAndroid Build Coastguard Worker     is_optimizer_param_group_equal<DerivedOptimizerOptions>(
159*da0073e9SAndroid Build Coastguard Worker         optim3_2_param_groups[i], optim3_param_groups[i]);
160*da0073e9SAndroid Build Coastguard Worker     is_optimizer_state_equal<DerivedOptimizerParamState>(
161*da0073e9SAndroid Build Coastguard Worker         optim3_2_state, optim3_state);
162*da0073e9SAndroid Build Coastguard Worker   }
163*da0073e9SAndroid Build Coastguard Worker 
164*da0073e9SAndroid Build Coastguard Worker   // Do step2 for model 3
165*da0073e9SAndroid Build Coastguard Worker   step(optim3_2, model3);
166*da0073e9SAndroid Build Coastguard Worker 
167*da0073e9SAndroid Build Coastguard Worker   param1 = model1->named_parameters();
168*da0073e9SAndroid Build Coastguard Worker   param2 = model2->named_parameters();
169*da0073e9SAndroid Build Coastguard Worker   param3 = model3->named_parameters();
170*da0073e9SAndroid Build Coastguard Worker   for (const auto& p : param1) {
171*da0073e9SAndroid Build Coastguard Worker     const auto& name = p.key();
172*da0073e9SAndroid Build Coastguard Worker     // Model 1 and 3 should be the same
173*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
174*da0073e9SAndroid Build Coastguard Worker         param1[name].norm().item<float>() == param3[name].norm().item<float>());
175*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
176*da0073e9SAndroid Build Coastguard Worker         param1[name].norm().item<float>() != param2[name].norm().item<float>());
177*da0073e9SAndroid Build Coastguard Worker   }
178*da0073e9SAndroid Build Coastguard Worker }
179*da0073e9SAndroid Build Coastguard Worker 
180*da0073e9SAndroid Build Coastguard Worker /// Utility function to save a value of `int64_t` type.
write_int_value(torch::serialize::OutputArchive & archive,const std::string & key,const int64_t & value)181*da0073e9SAndroid Build Coastguard Worker void write_int_value(
182*da0073e9SAndroid Build Coastguard Worker     torch::serialize::OutputArchive& archive,
183*da0073e9SAndroid Build Coastguard Worker     const std::string& key,
184*da0073e9SAndroid Build Coastguard Worker     const int64_t& value) {
185*da0073e9SAndroid Build Coastguard Worker   archive.write(key, c10::IValue(value));
186*da0073e9SAndroid Build Coastguard Worker }
187*da0073e9SAndroid Build Coastguard Worker // Utility function to save a vector of buffers.
188*da0073e9SAndroid Build Coastguard Worker template <typename BufferContainer>
write_tensors_to_archive(torch::serialize::OutputArchive & archive,const std::string & key,const BufferContainer & buffers)189*da0073e9SAndroid Build Coastguard Worker void write_tensors_to_archive(
190*da0073e9SAndroid Build Coastguard Worker     torch::serialize::OutputArchive& archive,
191*da0073e9SAndroid Build Coastguard Worker     const std::string& key,
192*da0073e9SAndroid Build Coastguard Worker     const BufferContainer& buffers) {
193*da0073e9SAndroid Build Coastguard Worker   archive.write(
194*da0073e9SAndroid Build Coastguard Worker       key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
195*da0073e9SAndroid Build Coastguard Worker   for (const auto index : c10::irange(buffers.size())) {
196*da0073e9SAndroid Build Coastguard Worker     archive.write(
197*da0073e9SAndroid Build Coastguard Worker         key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
198*da0073e9SAndroid Build Coastguard Worker   }
199*da0073e9SAndroid Build Coastguard Worker }
200*da0073e9SAndroid Build Coastguard Worker 
201*da0073e9SAndroid Build Coastguard Worker // Utility function to save a vector of step buffers.
write_step_buffers(torch::serialize::OutputArchive & archive,const std::string & key,const std::vector<int64_t> & steps)202*da0073e9SAndroid Build Coastguard Worker void write_step_buffers(
203*da0073e9SAndroid Build Coastguard Worker     torch::serialize::OutputArchive& archive,
204*da0073e9SAndroid Build Coastguard Worker     const std::string& key,
205*da0073e9SAndroid Build Coastguard Worker     const std::vector<int64_t>& steps) {
206*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> tensors;
207*da0073e9SAndroid Build Coastguard Worker   tensors.reserve(steps.size());
208*da0073e9SAndroid Build Coastguard Worker   for (const auto& step : steps) {
209*da0073e9SAndroid Build Coastguard Worker     tensors.push_back(torch::tensor(static_cast<int64_t>(step)));
210*da0073e9SAndroid Build Coastguard Worker   }
211*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(archive, key, tensors);
212*da0073e9SAndroid Build Coastguard Worker }
213*da0073e9SAndroid Build Coastguard Worker 
// Calls `funcname(optimizer, filename)` (e.g. torch::load) while a
// WarningCapture is alive. Asserts that the captured warning text contains
// exactly one occurrence of "old serialization".
// The do { ... } while (false) wrapper makes the macro a single statement
// that takes a trailing semicolon, so it is safe in an unbraced if/else.
#define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \
  do {                                                                       \
    WarningCapture warnings;                                                 \
    funcname(optimizer, filename);                                           \
    ASSERT_EQ(                                                               \
        count_substr_occurrences(warnings.str(), "old serialization"), 1);   \
  } while (false)
221*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, KeysFunc) {
  // Write three integer entries named "element/0" .. "element/2".
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    writer.write(
        "element/" + std::to_string(i), c10::IValue(static_cast<int64_t>(i)));
  }
  auto tempfile = c10::make_tempfile();
  writer.save_to(tempfile.name);

  // keys() on the reloaded archive must list exactly those names, in order.
  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);
  const std::vector<std::string> keys = reader.keys();
  ASSERT_EQ(keys.size(), 3);
  for (const auto i : c10::irange(keys.size())) {
    ASSERT_EQ(keys[i], "element/" + std::to_string(i));
  }
}
238*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, TryReadFunc) {
  // Write three integer entries named "element/0" .. "element/2".
  torch::serialize::OutputArchive writer;
  for (const auto i : c10::irange(3)) {
    const c10::IValue value(static_cast<int64_t>(i));
    writer.write("element/" + std::to_string(i), value);
  }
  auto tempfile = c10::make_tempfile();
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);

  // try_read returns false for a key that was never written ("1"), and true
  // (filling `ivalue`) for one that was.
  c10::IValue ivalue;
  ASSERT_FALSE(reader.try_read("1", ivalue));
  ASSERT_TRUE(reader.try_read("element/1", ivalue));
  ASSERT_EQ(ivalue.toInt(), 1);
}
254*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, Basic) {
  torch::manual_seed(0);

  // A plain contiguous tensor must survive an in-memory round trip.
  const auto original = torch::randn({5, 5});
  const auto restored = save_and_load(original);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
265*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,MathBits)266*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, MathBits) {
267*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
268*da0073e9SAndroid Build Coastguard Worker 
269*da0073e9SAndroid Build Coastguard Worker   auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat);
270*da0073e9SAndroid Build Coastguard Worker   auto x = torch::randn({5, 5}, options);
271*da0073e9SAndroid Build Coastguard Worker   {
272*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::conj(x);
273*da0073e9SAndroid Build Coastguard Worker     auto actual = save_and_load(expected);
274*da0073e9SAndroid Build Coastguard Worker 
275*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.defined());
276*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
277*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.allclose(expected));
278*da0073e9SAndroid Build Coastguard Worker   }
279*da0073e9SAndroid Build Coastguard Worker 
280*da0073e9SAndroid Build Coastguard Worker   {
281*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::_neg_view(x);
282*da0073e9SAndroid Build Coastguard Worker     auto actual = save_and_load(expected);
283*da0073e9SAndroid Build Coastguard Worker 
284*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.defined());
285*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
286*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.allclose(expected));
287*da0073e9SAndroid Build Coastguard Worker   }
288*da0073e9SAndroid Build Coastguard Worker 
289*da0073e9SAndroid Build Coastguard Worker   {
290*da0073e9SAndroid Build Coastguard Worker     auto expected = torch::conj(torch::_neg_view(x));
291*da0073e9SAndroid Build Coastguard Worker     auto actual = save_and_load(expected);
292*da0073e9SAndroid Build Coastguard Worker 
293*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.defined());
294*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec());
295*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(actual.allclose(expected));
296*da0073e9SAndroid Build Coastguard Worker   }
297*da0073e9SAndroid Build Coastguard Worker 
298*da0073e9SAndroid Build Coastguard Worker   {
299*da0073e9SAndroid Build Coastguard Worker     // We don't support serializing `ZeroTensor` as it is not public facing yet.
300*da0073e9SAndroid Build Coastguard Worker     // If in future, `ZeroTensor` serialization is supported, this test should
301*da0073e9SAndroid Build Coastguard Worker     // start failing!
302*da0073e9SAndroid Build Coastguard Worker     auto t = torch::_efficientzerotensor({5, 5});
303*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,");
304*da0073e9SAndroid Build Coastguard Worker   }
305*da0073e9SAndroid Build Coastguard Worker }
306*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, BasicToFile) {
  torch::manual_seed(0);
  const auto original = torch::randn({5, 5});

  // Round-trip through a temporary file instead of an in-memory stream.
  auto tempfile = c10::make_tempfile();
  torch::save(original, tempfile.name);
  torch::Tensor restored;
  torch::load(restored, tempfile.name);

  ASSERT_TRUE(restored.defined());
  ASSERT_EQ(original.sizes().vec(), restored.sizes().vec());
  ASSERT_TRUE(original.allclose(restored));
}
322*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, BasicViaFunc) {
  torch::manual_seed(0);

  auto x = torch::randn({5, 5});

  // Serialize through a writer callback that appends every chunk to a string.
  std::string serialized;
  torch::save(x, [&](const void* buf, size_t n) {
    serialized.append(reinterpret_cast<const char*>(buf), n);
    return n;
  });

  // Load back from the raw (pointer, size) buffer.
  torch::Tensor y;
  torch::load(y, serialized.data(), serialized.size());

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));

  // Load again through a positional reader callback plus a size callback.
  torch::Tensor z;
  torch::load(
      z,
      [&](uint64_t pos, void* buf, size_t n) -> size_t {
        if (pos >= serialized.size())
          return 0;
        // Clamp to the bytes remaining after `pos`. Subtracting first avoids
        // the unsigned overflow that `pos + n` could hit for a very large n.
        const size_t nbytes =
            std::min(n, serialized.size() - static_cast<size_t>(pos));
        std::memcpy(buf, serialized.data() + pos, nbytes);
        return nbytes;
      },
      [&]() -> size_t { return serialized.size(); });
  ASSERT_TRUE(z.defined());
  ASSERT_EQ(x.sizes().vec(), z.sizes().vec());
  ASSERT_TRUE(x.allclose(z));
}
356*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, Resized) {
  torch::manual_seed(0);

  // Shrink an 11x5 tensor in place to 5x5. The result may still be backed by
  // the original, larger storage; only the visible 5x5 part must round-trip.
  auto x = torch::randn({11, 5});
  x.resize_({5, 5});
  const auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
368*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, Sliced) {
  torch::manual_seed(0);

  // Keep rows 1..4 of an 11x5 tensor and round-trip just that slice.
  auto x = torch::randn({11, 5});
  x = x.slice(/*dim=*/0, /*start=*/1, /*end=*/5);
  const auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
380*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, NonContiguous) {
  torch::manual_seed(0);

  // Keep columns 1..3 of an 11x5 tensor; rows of the result are not adjacent
  // in memory, which is what this test exercises.
  auto x = torch::randn({11, 5});
  x = x.slice(/*dim=*/1, /*start=*/1, /*end=*/4);
  const auto y = save_and_load(x);

  ASSERT_TRUE(y.defined());
  ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
  ASSERT_TRUE(x.allclose(y));
}
392*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, ErrorOnMissingKey) {
  // Three-level hierarchy M -> A -> B, where B owns a single buffer. The
  // submodule and buffer names are supplied by the caller.
  struct B : torch::nn::Module {
    explicit B(const std::string& name_c) {
      register_buffer(name_c, torch::ones(5, torch::kFloat));
    }
  };
  struct A : torch::nn::Module {
    A(const std::string& name_b, const std::string& name_c) {
      register_module(name_b, std::make_shared<B>(name_c));
    }
  };
  struct M : torch::nn::Module {
    M(const std::string& name_a,
      const std::string& name_b,
      const std::string& name_c) {
      register_module(name_a, std::make_shared<A>(name_b, name_c));
    }
  };

  // create a hierarchy of models with names differing below the top level
  auto model1 = std::make_shared<M>("a", "b", "c");
  auto model2 = std::make_shared<M>("a", "b", "x");
  auto model3 = std::make_shared<M>("a", "x", "c");

  std::stringstream stream;
  torch::save(model1, stream);
  // We want the errors to contain hierarchy information, too.
  ASSERT_THROWS_WITH(
      torch::load(model2, stream), "No such serialized tensor 'a.b.x'");
  // Rewind for the second load. First clear any error/eof flags the failed
  // load may have left: seekg() does not move a stream whose failbit is set.
  stream.clear();
  stream.seekg(0, stream.beg);
  ASSERT_THROWS_WITH(
      torch::load(model3, stream), "No such serialized submodule: 'a.x'");
}
426*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,XOR)427*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, XOR) {
428*da0073e9SAndroid Build Coastguard Worker   // We better be able to save and load an XOR model!
429*da0073e9SAndroid Build Coastguard Worker   auto getLoss = [](Sequential model, uint32_t batch_size) {
430*da0073e9SAndroid Build Coastguard Worker     auto inputs = torch::empty({batch_size, 2});
431*da0073e9SAndroid Build Coastguard Worker     auto labels = torch::empty({batch_size});
432*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(batch_size)) {
433*da0073e9SAndroid Build Coastguard Worker       inputs[i] = torch::randint(2, {2}, torch::kInt64);
434*da0073e9SAndroid Build Coastguard Worker       labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
435*da0073e9SAndroid Build Coastguard Worker     }
436*da0073e9SAndroid Build Coastguard Worker     auto x = model->forward<torch::Tensor>(inputs);
437*da0073e9SAndroid Build Coastguard Worker     return torch::binary_cross_entropy(x, labels);
438*da0073e9SAndroid Build Coastguard Worker   };
439*da0073e9SAndroid Build Coastguard Worker 
440*da0073e9SAndroid Build Coastguard Worker   auto model = xor_model();
441*da0073e9SAndroid Build Coastguard Worker   auto model2 = xor_model();
442*da0073e9SAndroid Build Coastguard Worker   auto model3 = xor_model();
443*da0073e9SAndroid Build Coastguard Worker   auto optimizer = torch::optim::SGD(
444*da0073e9SAndroid Build Coastguard Worker       model->parameters(),
445*da0073e9SAndroid Build Coastguard Worker       torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
446*da0073e9SAndroid Build Coastguard Worker           1e-6));
447*da0073e9SAndroid Build Coastguard Worker 
448*da0073e9SAndroid Build Coastguard Worker   float running_loss = 1;
449*da0073e9SAndroid Build Coastguard Worker   int epoch = 0;
450*da0073e9SAndroid Build Coastguard Worker   while (running_loss > 0.1) {
451*da0073e9SAndroid Build Coastguard Worker     torch::Tensor loss = getLoss(model, 4);
452*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
453*da0073e9SAndroid Build Coastguard Worker     loss.backward();
454*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
455*da0073e9SAndroid Build Coastguard Worker 
456*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
457*da0073e9SAndroid Build Coastguard Worker     running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
458*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(epoch, 3000);
459*da0073e9SAndroid Build Coastguard Worker     epoch++;
460*da0073e9SAndroid Build Coastguard Worker   }
461*da0073e9SAndroid Build Coastguard Worker 
462*da0073e9SAndroid Build Coastguard Worker   auto tempfile = c10::make_tempfile();
463*da0073e9SAndroid Build Coastguard Worker   torch::save(model, tempfile.name);
464*da0073e9SAndroid Build Coastguard Worker   torch::load(model2, tempfile.name);
465*da0073e9SAndroid Build Coastguard Worker 
466*da0073e9SAndroid Build Coastguard Worker   auto loss = getLoss(model2, 100);
467*da0073e9SAndroid Build Coastguard Worker   ASSERT_LT(loss.item<float>(), 0.1);
468*da0073e9SAndroid Build Coastguard Worker }
469*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim)470*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim) {
471*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
472*da0073e9SAndroid Build Coastguard Worker   auto model2 = Linear(5, 2);
473*da0073e9SAndroid Build Coastguard Worker   auto model3 = Linear(5, 2);
474*da0073e9SAndroid Build Coastguard Worker 
475*da0073e9SAndroid Build Coastguard Worker   // Models 1, 2, 3 will have the same parameters.
476*da0073e9SAndroid Build Coastguard Worker   auto model_tempfile = c10::make_tempfile();
477*da0073e9SAndroid Build Coastguard Worker   torch::save(model1, model_tempfile.name);
478*da0073e9SAndroid Build Coastguard Worker   torch::load(model2, model_tempfile.name);
479*da0073e9SAndroid Build Coastguard Worker   torch::load(model3, model_tempfile.name);
480*da0073e9SAndroid Build Coastguard Worker 
481*da0073e9SAndroid Build Coastguard Worker   auto param1 = model1->named_parameters();
482*da0073e9SAndroid Build Coastguard Worker   auto param2 = model2->named_parameters();
483*da0073e9SAndroid Build Coastguard Worker   auto param3 = model3->named_parameters();
484*da0073e9SAndroid Build Coastguard Worker   for (const auto& p : param1) {
485*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(p->allclose(param2[p.key()]));
486*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()]));
487*da0073e9SAndroid Build Coastguard Worker   }
488*da0073e9SAndroid Build Coastguard Worker 
489*da0073e9SAndroid Build Coastguard Worker   // Make some optimizers with momentum (and thus state)
490*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::SGD(
491*da0073e9SAndroid Build Coastguard Worker       model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
492*da0073e9SAndroid Build Coastguard Worker   auto optim2 = torch::optim::SGD(
493*da0073e9SAndroid Build Coastguard Worker       model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
494*da0073e9SAndroid Build Coastguard Worker   auto optim2_2 = torch::optim::SGD(
495*da0073e9SAndroid Build Coastguard Worker       model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
496*da0073e9SAndroid Build Coastguard Worker   auto optim3 = torch::optim::SGD(
497*da0073e9SAndroid Build Coastguard Worker       model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
498*da0073e9SAndroid Build Coastguard Worker   auto optim3_2 = torch::optim::SGD(
499*da0073e9SAndroid Build Coastguard Worker       model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9));
500*da0073e9SAndroid Build Coastguard Worker 
501*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
502*da0073e9SAndroid Build Coastguard Worker 
503*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
504*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
505*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
506*da0073e9SAndroid Build Coastguard Worker     y.backward();
507*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
508*da0073e9SAndroid Build Coastguard Worker   };
509*da0073e9SAndroid Build Coastguard Worker 
510*da0073e9SAndroid Build Coastguard Worker   // Do 2 steps of model1
511*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
512*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
513*da0073e9SAndroid Build Coastguard Worker 
514*da0073e9SAndroid Build Coastguard Worker   // Do 2 steps of model 2 without saving the optimizer
515*da0073e9SAndroid Build Coastguard Worker   step(optim2, model2);
516*da0073e9SAndroid Build Coastguard Worker   step(optim2_2, model2);
517*da0073e9SAndroid Build Coastguard Worker 
518*da0073e9SAndroid Build Coastguard Worker   // Do 2 steps of model 3 while saving the optimizer
519*da0073e9SAndroid Build Coastguard Worker   step(optim3, model3);
520*da0073e9SAndroid Build Coastguard Worker 
521*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile = c10::make_tempfile();
522*da0073e9SAndroid Build Coastguard Worker   torch::save(optim3, optim_tempfile.name);
523*da0073e9SAndroid Build Coastguard Worker   torch::load(optim3_2, optim_tempfile.name);
524*da0073e9SAndroid Build Coastguard Worker   step(optim3_2, model3);
525*da0073e9SAndroid Build Coastguard Worker 
526*da0073e9SAndroid Build Coastguard Worker   param1 = model1->named_parameters();
527*da0073e9SAndroid Build Coastguard Worker   param2 = model2->named_parameters();
528*da0073e9SAndroid Build Coastguard Worker   param3 = model3->named_parameters();
529*da0073e9SAndroid Build Coastguard Worker   for (const auto& p : param1) {
530*da0073e9SAndroid Build Coastguard Worker     const auto& name = p.key();
531*da0073e9SAndroid Build Coastguard Worker     // Model 1 and 3 should be the same
532*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
533*da0073e9SAndroid Build Coastguard Worker         param1[name].norm().item<float>() == param3[name].norm().item<float>());
534*da0073e9SAndroid Build Coastguard Worker     ASSERT_TRUE(
535*da0073e9SAndroid Build Coastguard Worker         param1[name].norm().item<float>() != param2[name].norm().item<float>());
536*da0073e9SAndroid Build Coastguard Worker   }
537*da0073e9SAndroid Build Coastguard Worker }
538*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_Adagrad)539*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_Adagrad) {
540*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<Adagrad, AdagradOptions, AdagradParamState>(
541*da0073e9SAndroid Build Coastguard Worker       AdagradOptions(1e-1));
542*da0073e9SAndroid Build Coastguard Worker 
543*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
544*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
545*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::Adagrad(
546*da0073e9SAndroid Build Coastguard Worker       model1->parameters(), torch::optim::AdagradOptions(1e-1));
547*da0073e9SAndroid Build Coastguard Worker 
548*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
549*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
550*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
551*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
552*da0073e9SAndroid Build Coastguard Worker     y.backward();
553*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
554*da0073e9SAndroid Build Coastguard Worker   };
555*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
556*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 =
557*da0073e9SAndroid Build Coastguard Worker       Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1));
558*da0073e9SAndroid Build Coastguard Worker 
559*da0073e9SAndroid Build Coastguard Worker   // fill up with optim1 sum_buffers
560*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> sum_buffers;
561*da0073e9SAndroid Build Coastguard Worker   // fill up with optim1 state_buffers
562*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> step_buffers;
563*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
564*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state = optim1.state();
565*da0073e9SAndroid Build Coastguard Worker   for (const auto& param : params_) {
566*da0073e9SAndroid Build Coastguard Worker     auto key_ = param.unsafeGetTensorImpl();
567*da0073e9SAndroid Build Coastguard Worker     const AdagradParamState& curr_state_ =
568*da0073e9SAndroid Build Coastguard Worker         static_cast<const AdagradParamState&>(*(optim1_state.at(key_).get()));
569*da0073e9SAndroid Build Coastguard Worker     sum_buffers.emplace_back(curr_state_.sum());
570*da0073e9SAndroid Build Coastguard Worker     step_buffers.emplace_back(curr_state_.step());
571*da0073e9SAndroid Build Coastguard Worker   }
572*da0073e9SAndroid Build Coastguard Worker   // write sum_buffers and step_buffers to the file
573*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
574*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
575*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers);
576*da0073e9SAndroid Build Coastguard Worker   write_step_buffers(output_archive, "step_buffers", step_buffers);
577*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
578*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
579*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
580*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<AdagradParamState>(optim1.state(), optim1_2.state());
581*da0073e9SAndroid Build Coastguard Worker }
582*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_SGD)583*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_SGD) {
584*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<SGD, SGDOptions, SGDParamState>(
585*da0073e9SAndroid Build Coastguard Worker       SGDOptions(1e-1).momentum(0.9));
586*da0073e9SAndroid Build Coastguard Worker 
587*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
588*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
589*da0073e9SAndroid Build Coastguard Worker   auto model1_params = model1->parameters();
590*da0073e9SAndroid Build Coastguard Worker   // added a tensor for lazy init check - when all params do not have a momentum
591*da0073e9SAndroid Build Coastguard Worker   // buffer entry
592*da0073e9SAndroid Build Coastguard Worker   model1_params.emplace_back(torch::randn({2, 3}));
593*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::SGD(
594*da0073e9SAndroid Build Coastguard Worker       model1_params, torch::optim::SGDOptions(0.01).momentum(0.9));
595*da0073e9SAndroid Build Coastguard Worker 
596*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
597*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
598*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
599*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
600*da0073e9SAndroid Build Coastguard Worker     y.backward();
601*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
602*da0073e9SAndroid Build Coastguard Worker   };
603*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
604*da0073e9SAndroid Build Coastguard Worker 
605*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> momentum_buffers;
606*da0073e9SAndroid Build Coastguard Worker   int64_t iteration_{0};
607*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
608*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state = optim1.state();
609*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params_.size())) {
610*da0073e9SAndroid Build Coastguard Worker     if (i != (params_.size() - 1)) {
611*da0073e9SAndroid Build Coastguard Worker       auto key_ = params_[i].unsafeGetTensorImpl();
612*da0073e9SAndroid Build Coastguard Worker       const SGDParamState& curr_state_ =
613*da0073e9SAndroid Build Coastguard Worker           static_cast<const SGDParamState&>(*(optim1_state.at(key_).get()));
614*da0073e9SAndroid Build Coastguard Worker       momentum_buffers.emplace_back(curr_state_.momentum_buffer());
615*da0073e9SAndroid Build Coastguard Worker     }
616*da0073e9SAndroid Build Coastguard Worker   }
617*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1));
618*da0073e9SAndroid Build Coastguard Worker   // write momentum_buffers to the file
619*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
620*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
621*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
622*da0073e9SAndroid Build Coastguard Worker       output_archive, "momentum_buffers", momentum_buffers);
623*da0073e9SAndroid Build Coastguard Worker   write_int_value(output_archive, "iteration_", iteration_);
624*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
625*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 =
626*da0073e9SAndroid Build Coastguard Worker       SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9));
627*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
628*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
629*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<SGDParamState>(optim1.state(), optim1_2.state());
630*da0073e9SAndroid Build Coastguard Worker }
631*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_Adam)632*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_Adam) {
633*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<Adam, AdamOptions, AdamParamState>(
634*da0073e9SAndroid Build Coastguard Worker       AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5));
635*da0073e9SAndroid Build Coastguard Worker 
636*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
637*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
638*da0073e9SAndroid Build Coastguard Worker   auto model1_params = model1->parameters();
639*da0073e9SAndroid Build Coastguard Worker   // added a tensor for lazy init check - when all params do not have entry in
640*da0073e9SAndroid Build Coastguard Worker   // buffers
641*da0073e9SAndroid Build Coastguard Worker   model1_params.emplace_back(torch::randn({2, 3}));
642*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::Adam(
643*da0073e9SAndroid Build Coastguard Worker       model1_params, torch::optim::AdamOptions().weight_decay(0.5));
644*da0073e9SAndroid Build Coastguard Worker 
645*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
646*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
647*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
648*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
649*da0073e9SAndroid Build Coastguard Worker     y.backward();
650*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
651*da0073e9SAndroid Build Coastguard Worker   };
652*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
653*da0073e9SAndroid Build Coastguard Worker 
654*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> step_buffers;
655*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> exp_average_buffers;
656*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> exp_average_sq_buffers;
657*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> max_exp_average_sq_buffers;
658*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
659*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state = optim1.state();
660*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params_.size())) {
661*da0073e9SAndroid Build Coastguard Worker     if (i != (params_.size() - 1)) {
662*da0073e9SAndroid Build Coastguard Worker       auto key_ = params_[i].unsafeGetTensorImpl();
663*da0073e9SAndroid Build Coastguard Worker       const AdamParamState& curr_state_ =
664*da0073e9SAndroid Build Coastguard Worker           static_cast<const AdamParamState&>(*(optim1_state.at(key_).get()));
665*da0073e9SAndroid Build Coastguard Worker       step_buffers.emplace_back(curr_state_.step());
666*da0073e9SAndroid Build Coastguard Worker       exp_average_buffers.emplace_back(curr_state_.exp_avg());
667*da0073e9SAndroid Build Coastguard Worker       exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
668*da0073e9SAndroid Build Coastguard Worker       if (curr_state_.max_exp_avg_sq().defined()) {
669*da0073e9SAndroid Build Coastguard Worker         max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
670*da0073e9SAndroid Build Coastguard Worker       }
671*da0073e9SAndroid Build Coastguard Worker     }
672*da0073e9SAndroid Build Coastguard Worker   }
673*da0073e9SAndroid Build Coastguard Worker   // write buffers to the file
674*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
675*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
676*da0073e9SAndroid Build Coastguard Worker   write_step_buffers(output_archive, "step_buffers", step_buffers);
677*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
678*da0073e9SAndroid Build Coastguard Worker       output_archive, "exp_average_buffers", exp_average_buffers);
679*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
680*da0073e9SAndroid Build Coastguard Worker       output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
681*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
682*da0073e9SAndroid Build Coastguard Worker       output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
683*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
684*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions());
685*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
686*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
687*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<AdamParamState>(optim1.state(), optim1_2.state());
688*da0073e9SAndroid Build Coastguard Worker }
689*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_AdamW)690*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_AdamW) {
691*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<AdamW, AdamWOptions, AdamWParamState>(
692*da0073e9SAndroid Build Coastguard Worker       AdamWOptions().lr(0.99999).amsgrad(true).betas(
693*da0073e9SAndroid Build Coastguard Worker           std::make_tuple(0.999, 0.1)));
694*da0073e9SAndroid Build Coastguard Worker 
695*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
696*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
697*da0073e9SAndroid Build Coastguard Worker   auto model1_params = model1->parameters();
698*da0073e9SAndroid Build Coastguard Worker   // added a tensor for lazy init check - when all params do not have entry in
699*da0073e9SAndroid Build Coastguard Worker   // buffers
700*da0073e9SAndroid Build Coastguard Worker   model1_params.emplace_back(torch::randn({2, 3}));
701*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::AdamW(
702*da0073e9SAndroid Build Coastguard Worker       model1_params, torch::optim::AdamWOptions().weight_decay(0.5));
703*da0073e9SAndroid Build Coastguard Worker 
704*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
705*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
706*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
707*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
708*da0073e9SAndroid Build Coastguard Worker     y.backward();
709*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
710*da0073e9SAndroid Build Coastguard Worker   };
711*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
712*da0073e9SAndroid Build Coastguard Worker 
713*da0073e9SAndroid Build Coastguard Worker   std::vector<int64_t> step_buffers;
714*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> exp_average_buffers;
715*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> exp_average_sq_buffers;
716*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> max_exp_average_sq_buffers;
717*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
718*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state = optim1.state();
719*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params_.size())) {
720*da0073e9SAndroid Build Coastguard Worker     if (i != (params_.size() - 1)) {
721*da0073e9SAndroid Build Coastguard Worker       auto key_ = params_[i].unsafeGetTensorImpl();
722*da0073e9SAndroid Build Coastguard Worker       const AdamWParamState& curr_state_ =
723*da0073e9SAndroid Build Coastguard Worker           static_cast<const AdamWParamState&>(*(optim1_state.at(key_).get()));
724*da0073e9SAndroid Build Coastguard Worker       step_buffers.emplace_back(curr_state_.step());
725*da0073e9SAndroid Build Coastguard Worker       exp_average_buffers.emplace_back(curr_state_.exp_avg());
726*da0073e9SAndroid Build Coastguard Worker       exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq());
727*da0073e9SAndroid Build Coastguard Worker       if (curr_state_.max_exp_avg_sq().defined()) {
728*da0073e9SAndroid Build Coastguard Worker         max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq());
729*da0073e9SAndroid Build Coastguard Worker       }
730*da0073e9SAndroid Build Coastguard Worker     }
731*da0073e9SAndroid Build Coastguard Worker   }
732*da0073e9SAndroid Build Coastguard Worker   // write buffers to the file
733*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
734*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
735*da0073e9SAndroid Build Coastguard Worker   write_step_buffers(output_archive, "step_buffers", step_buffers);
736*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
737*da0073e9SAndroid Build Coastguard Worker       output_archive, "exp_average_buffers", exp_average_buffers);
738*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
739*da0073e9SAndroid Build Coastguard Worker       output_archive, "exp_average_sq_buffers", exp_average_sq_buffers);
740*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
741*da0073e9SAndroid Build Coastguard Worker       output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers);
742*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
743*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions());
744*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
745*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
746*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<AdamWParamState>(optim1.state(), optim1_2.state());
747*da0073e9SAndroid Build Coastguard Worker }
748*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_RMSprop)749*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_RMSprop) {
750*da0073e9SAndroid Build Coastguard Worker   auto options = RMSpropOptions(0.1).momentum(0.9).centered(true);
751*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<RMSprop, RMSpropOptions, RMSpropParamState>(options);
752*da0073e9SAndroid Build Coastguard Worker 
753*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
754*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
755*da0073e9SAndroid Build Coastguard Worker   auto model1_params = model1->parameters();
756*da0073e9SAndroid Build Coastguard Worker 
757*da0073e9SAndroid Build Coastguard Worker   // added a tensor for lazy init check - when all params do not have a momentum
758*da0073e9SAndroid Build Coastguard Worker   // buffer entry
759*da0073e9SAndroid Build Coastguard Worker   model1_params.emplace_back(torch::randn({2, 3}));
760*da0073e9SAndroid Build Coastguard Worker   auto optim1 = torch::optim::RMSprop(model1_params, options);
761*da0073e9SAndroid Build Coastguard Worker 
762*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
763*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
764*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
765*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
766*da0073e9SAndroid Build Coastguard Worker     y.backward();
767*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
768*da0073e9SAndroid Build Coastguard Worker   };
769*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
770*da0073e9SAndroid Build Coastguard Worker 
771*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> square_average_buffers;
772*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> momentum_buffers;
773*da0073e9SAndroid Build Coastguard Worker   std::vector<at::Tensor> grad_average_buffers;
774*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
775*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state = optim1.state();
776*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params_.size())) {
777*da0073e9SAndroid Build Coastguard Worker     if (i != (params_.size() - 1)) {
778*da0073e9SAndroid Build Coastguard Worker       auto key_ = params_[i].unsafeGetTensorImpl();
779*da0073e9SAndroid Build Coastguard Worker       const RMSpropParamState& curr_state_ =
780*da0073e9SAndroid Build Coastguard Worker           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
781*da0073e9SAndroid Build Coastguard Worker       square_average_buffers.emplace_back(curr_state_.square_avg());
782*da0073e9SAndroid Build Coastguard Worker       if (curr_state_.momentum_buffer().defined()) {
783*da0073e9SAndroid Build Coastguard Worker         momentum_buffers.emplace_back(curr_state_.momentum_buffer());
784*da0073e9SAndroid Build Coastguard Worker       }
785*da0073e9SAndroid Build Coastguard Worker       if (curr_state_.grad_avg().defined()) {
786*da0073e9SAndroid Build Coastguard Worker         grad_average_buffers.emplace_back(curr_state_.grad_avg());
787*da0073e9SAndroid Build Coastguard Worker       }
788*da0073e9SAndroid Build Coastguard Worker     }
789*da0073e9SAndroid Build Coastguard Worker   }
790*da0073e9SAndroid Build Coastguard Worker   // write buffers to the file
791*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
792*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
793*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
794*da0073e9SAndroid Build Coastguard Worker       output_archive, "square_average_buffers", square_average_buffers);
795*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
796*da0073e9SAndroid Build Coastguard Worker       output_archive, "momentum_buffers", momentum_buffers);
797*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(
798*da0073e9SAndroid Build Coastguard Worker       output_archive, "grad_average_buffers", grad_average_buffers);
799*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
800*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 = RMSprop(model1_params, options);
801*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
802*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
803*da0073e9SAndroid Build Coastguard Worker   const auto& params1_2_ = optim1_2.param_groups()[0].params();
804*da0073e9SAndroid Build Coastguard Worker   auto& optim1_2_state = optim1_2.state();
805*da0073e9SAndroid Build Coastguard Worker   // old RMSprop didn't track step value
806*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(params1_2_.size())) {
807*da0073e9SAndroid Build Coastguard Worker     if (i != (params1_2_.size() - 1)) {
808*da0073e9SAndroid Build Coastguard Worker       auto key_ = params_[i].unsafeGetTensorImpl();
809*da0073e9SAndroid Build Coastguard Worker       const RMSpropParamState& curr_state_ =
810*da0073e9SAndroid Build Coastguard Worker           static_cast<const RMSpropParamState&>(*(optim1_state.at(key_).get()));
811*da0073e9SAndroid Build Coastguard Worker       RMSpropParamState& curr_state1_2_ =
812*da0073e9SAndroid Build Coastguard Worker           static_cast<RMSpropParamState&>(*(optim1_2_state.at(key_).get()));
813*da0073e9SAndroid Build Coastguard Worker       curr_state1_2_.step(curr_state_.step());
814*da0073e9SAndroid Build Coastguard Worker     }
815*da0073e9SAndroid Build Coastguard Worker   }
816*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<RMSpropParamState>(optim1.state(), optim1_2.state());
817*da0073e9SAndroid Build Coastguard Worker }
818*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,Optim_LBFGS)819*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, Optim_LBFGS) {
820*da0073e9SAndroid Build Coastguard Worker   test_serialize_optimizer<LBFGS, LBFGSOptions, LBFGSParamState>(
821*da0073e9SAndroid Build Coastguard Worker       LBFGSOptions(), true);
822*da0073e9SAndroid Build Coastguard Worker   // bc compatibility check
823*da0073e9SAndroid Build Coastguard Worker   auto model1 = Linear(5, 2);
824*da0073e9SAndroid Build Coastguard Worker   auto model1_params = model1->parameters();
825*da0073e9SAndroid Build Coastguard Worker   // added a tensor for lazy init check - when all params do not have entry in
826*da0073e9SAndroid Build Coastguard Worker   // buffers
827*da0073e9SAndroid Build Coastguard Worker   model1_params.emplace_back(torch::randn({2, 3}));
828*da0073e9SAndroid Build Coastguard Worker   auto optim1 =
829*da0073e9SAndroid Build Coastguard Worker       torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions());
830*da0073e9SAndroid Build Coastguard Worker 
831*da0073e9SAndroid Build Coastguard Worker   auto x = torch::ones({10, 5});
832*da0073e9SAndroid Build Coastguard Worker   auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) {
833*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
834*da0073e9SAndroid Build Coastguard Worker     auto y = model->forward(x).sum();
835*da0073e9SAndroid Build Coastguard Worker     y.backward();
836*da0073e9SAndroid Build Coastguard Worker     auto closure = []() { return torch::tensor({10}); };
837*da0073e9SAndroid Build Coastguard Worker     optimizer.step(closure);
838*da0073e9SAndroid Build Coastguard Worker   };
839*da0073e9SAndroid Build Coastguard Worker 
840*da0073e9SAndroid Build Coastguard Worker   step(optim1, model1);
841*da0073e9SAndroid Build Coastguard Worker 
842*da0073e9SAndroid Build Coastguard Worker   at::Tensor d, t, H_diag, prev_flat_grad, prev_loss;
843*da0073e9SAndroid Build Coastguard Worker   std::deque<at::Tensor> old_dirs, old_stps;
844*da0073e9SAndroid Build Coastguard Worker 
845*da0073e9SAndroid Build Coastguard Worker   const auto& params_ = optim1.param_groups()[0].params();
846*da0073e9SAndroid Build Coastguard Worker   auto key_ = params_[0].unsafeGetTensorImpl();
847*da0073e9SAndroid Build Coastguard Worker   const auto& optim1_state =
848*da0073e9SAndroid Build Coastguard Worker       static_cast<const LBFGSParamState&>(*(optim1.state().at(key_).get()));
849*da0073e9SAndroid Build Coastguard Worker   d = optim1_state.d();
850*da0073e9SAndroid Build Coastguard Worker   t = at::tensor(optim1_state.t());
851*da0073e9SAndroid Build Coastguard Worker   H_diag = optim1_state.H_diag();
852*da0073e9SAndroid Build Coastguard Worker   prev_flat_grad = optim1_state.prev_flat_grad();
853*da0073e9SAndroid Build Coastguard Worker   prev_loss = at::tensor(optim1_state.prev_loss());
854*da0073e9SAndroid Build Coastguard Worker   old_dirs = optim1_state.old_dirs();
855*da0073e9SAndroid Build Coastguard Worker 
856*da0073e9SAndroid Build Coastguard Worker   // write buffers to the file
857*da0073e9SAndroid Build Coastguard Worker   auto optim_tempfile_old_format = c10::make_tempfile();
858*da0073e9SAndroid Build Coastguard Worker   torch::serialize::OutputArchive output_archive;
859*da0073e9SAndroid Build Coastguard Worker   output_archive.write("d", d, /*is_buffer=*/true);
860*da0073e9SAndroid Build Coastguard Worker   output_archive.write("t", t, /*is_buffer=*/true);
861*da0073e9SAndroid Build Coastguard Worker   output_archive.write("H_diag", H_diag, /*is_buffer=*/true);
862*da0073e9SAndroid Build Coastguard Worker   output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true);
863*da0073e9SAndroid Build Coastguard Worker   output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true);
864*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(output_archive, "old_dirs", old_dirs);
865*da0073e9SAndroid Build Coastguard Worker   write_tensors_to_archive(output_archive, "old_stps", old_stps);
866*da0073e9SAndroid Build Coastguard Worker   output_archive.save_to(optim_tempfile_old_format.name);
867*da0073e9SAndroid Build Coastguard Worker 
868*da0073e9SAndroid Build Coastguard Worker   auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions());
869*da0073e9SAndroid Build Coastguard Worker   OLD_SERIALIZATION_LOGIC_WARNING_CHECK(
870*da0073e9SAndroid Build Coastguard Worker       torch::load, optim1_2, optim_tempfile_old_format.name);
871*da0073e9SAndroid Build Coastguard Worker 
872*da0073e9SAndroid Build Coastguard Worker   const auto& params1_2_ = optim1_2.param_groups()[0].params();
873*da0073e9SAndroid Build Coastguard Worker   auto param_key = params1_2_[0].unsafeGetTensorImpl();
874*da0073e9SAndroid Build Coastguard Worker   auto& optim1_2_state =
875*da0073e9SAndroid Build Coastguard Worker       static_cast<LBFGSParamState&>(*(optim1_2.state().at(param_key).get()));
876*da0073e9SAndroid Build Coastguard Worker 
877*da0073e9SAndroid Build Coastguard Worker   // old LBFGS didn't track func_evals, n_iter, ro, al values
878*da0073e9SAndroid Build Coastguard Worker   optim1_2_state.func_evals(optim1_state.func_evals());
879*da0073e9SAndroid Build Coastguard Worker   optim1_2_state.n_iter(optim1_state.n_iter());
880*da0073e9SAndroid Build Coastguard Worker   optim1_2_state.ro(optim1_state.ro());
881*da0073e9SAndroid Build Coastguard Worker   optim1_2_state.al(optim1_state.al());
882*da0073e9SAndroid Build Coastguard Worker 
883*da0073e9SAndroid Build Coastguard Worker   is_optimizer_state_equal<LBFGSParamState>(optim1.state(), optim1_2.state());
884*da0073e9SAndroid Build Coastguard Worker }
885*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest,XOR_CUDA)886*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, XOR_CUDA) {
887*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
888*da0073e9SAndroid Build Coastguard Worker   // We better be able to save and load a XOR model!
889*da0073e9SAndroid Build Coastguard Worker   auto getLoss = [](Sequential model,
890*da0073e9SAndroid Build Coastguard Worker                     uint32_t batch_size,
891*da0073e9SAndroid Build Coastguard Worker                     bool is_cuda = false) {
892*da0073e9SAndroid Build Coastguard Worker     auto inputs = torch::empty({batch_size, 2});
893*da0073e9SAndroid Build Coastguard Worker     auto labels = torch::empty({batch_size});
894*da0073e9SAndroid Build Coastguard Worker     if (is_cuda) {
895*da0073e9SAndroid Build Coastguard Worker       inputs = inputs.cuda();
896*da0073e9SAndroid Build Coastguard Worker       labels = labels.cuda();
897*da0073e9SAndroid Build Coastguard Worker     }
898*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(batch_size)) {
899*da0073e9SAndroid Build Coastguard Worker       inputs[i] = torch::randint(2, {2}, torch::kInt64);
900*da0073e9SAndroid Build Coastguard Worker       labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
901*da0073e9SAndroid Build Coastguard Worker     }
902*da0073e9SAndroid Build Coastguard Worker     auto x = model->forward<torch::Tensor>(inputs);
903*da0073e9SAndroid Build Coastguard Worker     return torch::binary_cross_entropy(x, labels);
904*da0073e9SAndroid Build Coastguard Worker   };
905*da0073e9SAndroid Build Coastguard Worker 
906*da0073e9SAndroid Build Coastguard Worker   auto model = xor_model();
907*da0073e9SAndroid Build Coastguard Worker   auto model2 = xor_model();
908*da0073e9SAndroid Build Coastguard Worker   auto model3 = xor_model();
909*da0073e9SAndroid Build Coastguard Worker   auto optimizer = torch::optim::SGD(
910*da0073e9SAndroid Build Coastguard Worker       model->parameters(),
911*da0073e9SAndroid Build Coastguard Worker       torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay(
912*da0073e9SAndroid Build Coastguard Worker           1e-6));
913*da0073e9SAndroid Build Coastguard Worker 
914*da0073e9SAndroid Build Coastguard Worker   float running_loss = 1;
915*da0073e9SAndroid Build Coastguard Worker   int epoch = 0;
916*da0073e9SAndroid Build Coastguard Worker   while (running_loss > 0.1) {
917*da0073e9SAndroid Build Coastguard Worker     torch::Tensor loss = getLoss(model, 4);
918*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
919*da0073e9SAndroid Build Coastguard Worker     loss.backward();
920*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
921*da0073e9SAndroid Build Coastguard Worker 
922*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
923*da0073e9SAndroid Build Coastguard Worker     running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
924*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(epoch, 3000);
925*da0073e9SAndroid Build Coastguard Worker     epoch++;
926*da0073e9SAndroid Build Coastguard Worker   }
927*da0073e9SAndroid Build Coastguard Worker 
928*da0073e9SAndroid Build Coastguard Worker   auto tempfile = c10::make_tempfile();
929*da0073e9SAndroid Build Coastguard Worker   torch::save(model, tempfile.name);
930*da0073e9SAndroid Build Coastguard Worker   torch::load(model2, tempfile.name);
931*da0073e9SAndroid Build Coastguard Worker 
932*da0073e9SAndroid Build Coastguard Worker   auto loss = getLoss(model2, 100);
933*da0073e9SAndroid Build Coastguard Worker   ASSERT_LT(loss.item<float>(), 0.1);
934*da0073e9SAndroid Build Coastguard Worker 
935*da0073e9SAndroid Build Coastguard Worker   model2->to(torch::kCUDA);
936*da0073e9SAndroid Build Coastguard Worker   loss = getLoss(model2, 100, true);
937*da0073e9SAndroid Build Coastguard Worker   ASSERT_LT(loss.item<float>(), 0.1);
938*da0073e9SAndroid Build Coastguard Worker 
939*da0073e9SAndroid Build Coastguard Worker   auto tempfile2 = c10::make_tempfile();
940*da0073e9SAndroid Build Coastguard Worker   torch::save(model2, tempfile2.name);
941*da0073e9SAndroid Build Coastguard Worker   torch::load(model3, tempfile2.name);
942*da0073e9SAndroid Build Coastguard Worker 
943*da0073e9SAndroid Build Coastguard Worker   loss = getLoss(model3, 100, true);
944*da0073e9SAndroid Build Coastguard Worker   ASSERT_LT(loss.item<float>(), 0.1);
945*da0073e9SAndroid Build Coastguard Worker }
946*da0073e9SAndroid Build Coastguard Worker 
TEST(
    SerializeTest,
    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
  // Module tree: Root -> "a" (Middle) -> { "b" (Empty), "c" (Leaf) }.
  // Only the leaf owns state: an int32 buffer of five ones.
  struct Leaf : torch::nn::Module {
    Leaf() {
      register_buffer("foo", torch::ones(5, torch::kInt32));
    }
  };
  struct Empty : torch::nn::Module {};
  struct Middle : torch::nn::Module {
    Middle() {
      register_module("b", std::make_shared<Empty>());
      register_module("c", std::make_shared<Leaf>());
    }
  };
  struct Root : torch::nn::Module {
    Root() {
      register_module("a", std::make_shared<Middle>());
    }
  };

  std::stringstream buffer;
  auto saved = std::make_shared<Root>();
  torch::save(saved, buffer);

  auto restored = std::make_shared<Root>();
  torch::load(restored, buffer);

  // The empty intermediate module must not prevent the leaf buffer from
  // being restored.
  ASSERT_EQ(restored->named_buffers()["a.c.foo"].sum().item<int>(), 5);
}
977*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, VectorOfTensors) {
  torch::manual_seed(0);

  std::vector<torch::Tensor> x_vec = {
      torch::randn({1, 2}), torch::randn({3, 4})};

  std::stringstream stream;
  torch::save(x_vec, stream);

  std::vector<torch::Tensor> y_vec;
  torch::load(y_vec, stream);

  // Check the lengths first: the loop below indexes y_vec with x_vec's
  // indices, so a shorter y_vec would otherwise be read out of bounds
  // instead of failing the test.
  ASSERT_EQ(x_vec.size(), y_vec.size());
  for (const auto i : c10::irange(x_vec.size())) {
    const auto& x = x_vec[i];
    const auto& y = y_vec[i];
    ASSERT_TRUE(y.defined());
    ASSERT_EQ(x.sizes().vec(), y.sizes().vec());
    ASSERT_TRUE(x.allclose(y));
  }
}
998*da0073e9SAndroid Build Coastguard Worker 
TEST(SerializeTest, IValue) {
  // Round-trips a single integer IValue through a file-backed archive, then
  // checks that reading an unknown key raises.
  const c10::IValue written(1);
  auto tempfile = c10::make_tempfile();

  torch::serialize::OutputArchive writer;
  writer.write("value", written);
  writer.save_to(tempfile.name);

  torch::serialize::InputArchive reader;
  reader.load_from(tempfile.name);

  c10::IValue restored;
  reader.read("value", restored);
  ASSERT_EQ(restored.toInt(), 1);

  ASSERT_THROWS_WITH(
      reader.read("bad_key", restored), "does not have a field with name");
}
1016*da0073e9SAndroid Build Coastguard Worker 
1017*da0073e9SAndroid Build Coastguard Worker // NOTE: if a `Module` contains unserializable submodules (e.g.
1018*da0073e9SAndroid Build Coastguard Worker // `nn::Functional`), we expect those submodules to be skipped when the `Module`
1019*da0073e9SAndroid Build Coastguard Worker // is being serialized.
TEST(SerializeTest,UnserializableSubmoduleIsSkippedWhenSavingModule)1020*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
1021*da0073e9SAndroid Build Coastguard Worker   struct A : torch::nn::Module {
1022*da0073e9SAndroid Build Coastguard Worker     A() {
1023*da0073e9SAndroid Build Coastguard Worker       register_module("relu", torch::nn::Functional(torch::relu));
1024*da0073e9SAndroid Build Coastguard Worker     }
1025*da0073e9SAndroid Build Coastguard Worker   };
1026*da0073e9SAndroid Build Coastguard Worker 
1027*da0073e9SAndroid Build Coastguard Worker   auto out = std::make_shared<A>();
1028*da0073e9SAndroid Build Coastguard Worker   std::stringstream ss;
1029*da0073e9SAndroid Build Coastguard Worker   torch::save(out, ss);
1030*da0073e9SAndroid Build Coastguard Worker 
1031*da0073e9SAndroid Build Coastguard Worker   torch::serialize::InputArchive archive;
1032*da0073e9SAndroid Build Coastguard Worker   archive.load_from(ss);
1033*da0073e9SAndroid Build Coastguard Worker   torch::serialize::InputArchive relu_archive;
1034*da0073e9SAndroid Build Coastguard Worker 
1035*da0073e9SAndroid Build Coastguard Worker   // Submodule with name "relu" should not exist in the `InputArchive`,
1036*da0073e9SAndroid Build Coastguard Worker   // because the "relu" submodule is an `nn::Functional` and is not
1037*da0073e9SAndroid Build Coastguard Worker   // serializable.
1038*da0073e9SAndroid Build Coastguard Worker   ASSERT_FALSE(archive.try_read("relu", relu_archive));
1039*da0073e9SAndroid Build Coastguard Worker }
1040*da0073e9SAndroid Build Coastguard Worker 
1041*da0073e9SAndroid Build Coastguard Worker // NOTE: If a `Module` contains unserializable submodules (e.g.
1042*da0073e9SAndroid Build Coastguard Worker // `nn::Functional`), we don't check the existence of those submodules in the
1043*da0073e9SAndroid Build Coastguard Worker // `InputArchive` when deserializing.
TEST(SerializeTest,UnserializableSubmoduleIsIgnoredWhenLoadingModule)1044*da0073e9SAndroid Build Coastguard Worker TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
1045*da0073e9SAndroid Build Coastguard Worker   struct B : torch::nn::Module {
1046*da0073e9SAndroid Build Coastguard Worker     B() {
1047*da0073e9SAndroid Build Coastguard Worker       register_module("relu1", torch::nn::Functional(torch::relu));
1048*da0073e9SAndroid Build Coastguard Worker       register_buffer("foo", torch::zeros(5, torch::kInt32));
1049*da0073e9SAndroid Build Coastguard Worker     }
1050*da0073e9SAndroid Build Coastguard Worker   };
1051*da0073e9SAndroid Build Coastguard Worker   struct A : torch::nn::Module {
1052*da0073e9SAndroid Build Coastguard Worker     A() {
1053*da0073e9SAndroid Build Coastguard Worker       register_module("b", std::make_shared<B>());
1054*da0073e9SAndroid Build Coastguard Worker       register_module("relu2", torch::nn::Functional(torch::relu));
1055*da0073e9SAndroid Build Coastguard Worker     }
1056*da0073e9SAndroid Build Coastguard Worker   };
1057*da0073e9SAndroid Build Coastguard Worker 
1058*da0073e9SAndroid Build Coastguard Worker   auto out = std::make_shared<A>();
1059*da0073e9SAndroid Build Coastguard Worker   // Manually change the values of "b.foo", so that we can check whether the
1060*da0073e9SAndroid Build Coastguard Worker   // buffer contains these values after deserialization.
1061*da0073e9SAndroid Build Coastguard Worker   out->named_buffers()["b.foo"].fill_(1);
1062*da0073e9SAndroid Build Coastguard Worker   auto tempfile = c10::make_tempfile();
1063*da0073e9SAndroid Build Coastguard Worker   torch::save(out, tempfile.name);
1064*da0073e9SAndroid Build Coastguard Worker 
1065*da0073e9SAndroid Build Coastguard Worker   torch::serialize::InputArchive archive;
1066*da0073e9SAndroid Build Coastguard Worker   archive.load_from(tempfile.name);
1067*da0073e9SAndroid Build Coastguard Worker   torch::serialize::InputArchive archive_b;
1068*da0073e9SAndroid Build Coastguard Worker   torch::serialize::InputArchive archive_relu;
1069*da0073e9SAndroid Build Coastguard Worker   torch::Tensor tensor_foo;
1070*da0073e9SAndroid Build Coastguard Worker 
1071*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(archive.try_read("b", archive_b));
1072*da0073e9SAndroid Build Coastguard Worker   ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));
1073*da0073e9SAndroid Build Coastguard Worker 
1074*da0073e9SAndroid Build Coastguard Worker   // Submodule with name "relu1" should not exist in `archive_b`, because the
1075*da0073e9SAndroid Build Coastguard Worker   // "relu1" submodule is an `nn::Functional` and is not serializable.
1076*da0073e9SAndroid Build Coastguard Worker   ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));
1077*da0073e9SAndroid Build Coastguard Worker 
1078*da0073e9SAndroid Build Coastguard Worker   // Submodule with name "relu2" should not exist in `archive`, because the
1079*da0073e9SAndroid Build Coastguard Worker   // "relu2" submodule is an `nn::Functional` and is not serializable.
1080*da0073e9SAndroid Build Coastguard Worker   ASSERT_FALSE(archive.try_read("relu2", archive_relu));
1081*da0073e9SAndroid Build Coastguard Worker 
1082*da0073e9SAndroid Build Coastguard Worker   auto in = std::make_shared<A>();
1083*da0073e9SAndroid Build Coastguard Worker   // `torch::load(...)` works without error, even though `A` contains the
1084*da0073e9SAndroid Build Coastguard Worker   // `nn::Functional` submodules while the serialized file doesn't, because the
1085*da0073e9SAndroid Build Coastguard Worker   // `nn::Functional` submodules are not serializable and thus ignored when
1086*da0073e9SAndroid Build Coastguard Worker   // deserializing.
1087*da0073e9SAndroid Build Coastguard Worker   torch::load(in, tempfile.name);
1088*da0073e9SAndroid Build Coastguard Worker 
1089*da0073e9SAndroid Build Coastguard Worker   // Check that the "b.foo" buffer is correctly deserialized from the file.
1090*da0073e9SAndroid Build Coastguard Worker   const int output = in->named_buffers()["b.foo"].sum().item<int>();
1091*da0073e9SAndroid Build Coastguard Worker   // `output` should equal to the sum of the values we manually assigned to
1092*da0073e9SAndroid Build Coastguard Worker   // "b.foo" before serialization.
1093*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(output, 5);
1094*da0073e9SAndroid Build Coastguard Worker }
1095