#include #include #include #include #include #include #include #include #include #include #include using namespace torch::test; using namespace torch::nn; using namespace torch::optim; namespace { Sequential xor_model() { return Sequential( Linear(2, 8), Functional(at::sigmoid), Linear(8, 1), Functional(at::sigmoid)); } torch::Tensor save_and_load(torch::Tensor input) { std::stringstream stream; torch::save(input, stream); torch::Tensor tensor; torch::load(tensor, stream); return tensor; } } // namespace template void is_optimizer_param_group_equal( const OptimizerParamGroup& lhs, const OptimizerParamGroup& rhs) { const auto& lhs_params = lhs.params(); const auto& rhs_params = rhs.params(); ASSERT_TRUE(lhs_params.size() == rhs_params.size()); for (const auto j : c10::irange(lhs_params.size())) { ASSERT_TRUE(torch::equal(lhs_params[j], rhs_params[j])); } ASSERT_TRUE( static_cast(lhs.options()) == static_cast(rhs.options())); } template void is_optimizer_state_equal( const ska::flat_hash_map>& lhs_state, const ska::flat_hash_map>& rhs_state) { ASSERT_TRUE(lhs_state.size() == rhs_state.size()); for (const auto& value : lhs_state) { auto found = rhs_state.find(value.first); ASSERT_TRUE(found != rhs_state.end()); const DerivedOptimizerParamState& lhs_curr_state = static_cast(*(value.second.get())); const DerivedOptimizerParamState& rhs_curr_state = static_cast(*(found->second.get())); ASSERT_TRUE(lhs_curr_state == rhs_curr_state); } } template < typename OptimizerClass, typename DerivedOptimizerOptions, typename DerivedOptimizerParamState> void test_serialize_optimizer( DerivedOptimizerOptions options, bool only_has_global_state = false) { torch::manual_seed(0); auto model1 = Linear(5, 2); auto model2 = Linear(5, 2); auto model3 = Linear(5, 2); // Models 1, 2, 3 will have the same parameters. 
auto model_tempfile = c10::make_tempfile(); torch::save(model1, model_tempfile.name); torch::load(model2, model_tempfile.name); torch::load(model3, model_tempfile.name); auto param1 = model1->named_parameters(); auto param2 = model2->named_parameters(); auto param3 = model3->named_parameters(); for (const auto& p : param1) { ASSERT_TRUE(p->allclose(param2[p.key()])); ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()])); } // Make some optimizers auto optim1 = OptimizerClass( {torch::optim::OptimizerParamGroup(model1->parameters())}, options); auto optim2 = OptimizerClass(model2->parameters(), options); auto optim2_2 = OptimizerClass(model2->parameters(), options); auto optim3 = OptimizerClass(model3->parameters(), options); auto optim3_2 = OptimizerClass(model3->parameters(), options); for (auto& param_group : optim3_2.param_groups()) { const double lr = param_group.options().get_lr(); // change the learning rate, which will be overwritten by the loading // otherwise, test cannot check if options are saved and loaded correctly param_group.options().set_lr(lr + 0.01); } auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); auto closure = []() { return torch::tensor({10}); }; optimizer.step(closure); }; // Do 2 steps of model1 step(optim1, model1); step(optim1, model1); // Do 2 steps of model 2 without saving the optimizer step(optim2, model2); step(optim2_2, model2); // Do 1 step of model 3 step(optim3, model3); // save the optimizer auto optim_tempfile = c10::make_tempfile(); torch::save(optim3, optim_tempfile.name); torch::load(optim3_2, optim_tempfile.name); auto& optim3_2_param_groups = optim3_2.param_groups(); auto& optim3_param_groups = optim3.param_groups(); auto& optim3_2_state = optim3_2.state(); auto& optim3_state = optim3.state(); // optim3_2 and optim1 should have param_groups and state of size 1 and // state_size respectively 
ASSERT_TRUE(optim3_2_param_groups.size() == 1); // state_size = 2 for all optimizers except LBFGS as LBFGS only maintains one // global state unsigned state_size = only_has_global_state ? 1 : 2; ASSERT_TRUE(optim3_2_state.size() == state_size); // optim3_2 and optim1 should have param_groups and state of same size ASSERT_TRUE(optim3_2_param_groups.size() == optim3_param_groups.size()); ASSERT_TRUE(optim3_2_state.size() == optim3_state.size()); // checking correctness of serialization logic for optimizer.param_groups_ and // optimizer.state_ for (const auto i : c10::irange(optim3_2_param_groups.size())) { is_optimizer_param_group_equal( optim3_2_param_groups[i], optim3_param_groups[i]); is_optimizer_state_equal( optim3_2_state, optim3_state); } // Do step2 for model 3 step(optim3_2, model3); param1 = model1->named_parameters(); param2 = model2->named_parameters(); param3 = model3->named_parameters(); for (const auto& p : param1) { const auto& name = p.key(); // Model 1 and 3 should be the same ASSERT_TRUE( param1[name].norm().item() == param3[name].norm().item()); ASSERT_TRUE( param1[name].norm().item() != param2[name].norm().item()); } } /// Utility function to save a value of `int64_t` type. void write_int_value( torch::serialize::OutputArchive& archive, const std::string& key, const int64_t& value) { archive.write(key, c10::IValue(value)); } // Utility function to save a vector of buffers. template void write_tensors_to_archive( torch::serialize::OutputArchive& archive, const std::string& key, const BufferContainer& buffers) { archive.write( key + "/size", torch::tensor(static_cast(buffers.size()))); for (const auto index : c10::irange(buffers.size())) { archive.write( key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true); } } // Utility function to save a vector of step buffers. 
void write_step_buffers( torch::serialize::OutputArchive& archive, const std::string& key, const std::vector& steps) { std::vector tensors; tensors.reserve(steps.size()); for (const auto& step : steps) { tensors.push_back(torch::tensor(static_cast(step))); } write_tensors_to_archive(archive, key, tensors); } #define OLD_SERIALIZATION_LOGIC_WARNING_CHECK(funcname, optimizer, filename) \ { \ WarningCapture warnings; \ funcname(optimizer, filename); \ ASSERT_EQ( \ count_substr_occurrences(warnings.str(), "old serialization"), 1); \ } TEST(SerializeTest, KeysFunc) { auto tempfile = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; for (const auto i : c10::irange(3)) { output_archive.write( "element/" + std::to_string(i), c10::IValue(static_cast(i))); } output_archive.save_to(tempfile.name); torch::serialize::InputArchive input_archive; input_archive.load_from(tempfile.name); std::vector keys = input_archive.keys(); ASSERT_EQ(keys.size(), 3); for (const auto i : c10::irange(keys.size())) { ASSERT_EQ(keys[i], "element/" + std::to_string(i)); } } TEST(SerializeTest, TryReadFunc) { auto tempfile = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; for (const auto i : c10::irange(3)) { output_archive.write( "element/" + std::to_string(i), c10::IValue(static_cast(i))); } output_archive.save_to(tempfile.name); torch::serialize::InputArchive input_archive; input_archive.load_from(tempfile.name); c10::IValue ivalue; ASSERT_FALSE(input_archive.try_read("1", ivalue)); ASSERT_TRUE(input_archive.try_read("element/1", ivalue)); ASSERT_EQ(ivalue.toInt(), 1); } TEST(SerializeTest, Basic) { torch::manual_seed(0); auto x = torch::randn({5, 5}); auto y = save_and_load(x); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } TEST(SerializeTest, MathBits) { torch::manual_seed(0); auto options = torch::TensorOptions{}.dtype(torch::kComplexFloat); auto x = torch::randn({5, 5}, options); { auto expected = 
torch::conj(x); auto actual = save_and_load(expected); ASSERT_TRUE(actual.defined()); ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } { auto expected = torch::_neg_view(x); auto actual = save_and_load(expected); ASSERT_TRUE(actual.defined()); ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } { auto expected = torch::conj(torch::_neg_view(x)); auto actual = save_and_load(expected); ASSERT_TRUE(actual.defined()); ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } { // We don't support serializing `ZeroTensor` as it is not public facing yet. // If in future, `ZeroTensor` serialization is supported, this test should // start failing! auto t = torch::_efficientzerotensor({5, 5}); ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,"); } } TEST(SerializeTest, BasicToFile) { torch::manual_seed(0); auto x = torch::randn({5, 5}); auto tempfile = c10::make_tempfile(); torch::save(x, tempfile.name); torch::Tensor y; torch::load(y, tempfile.name); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } TEST(SerializeTest, BasicViaFunc) { torch::manual_seed(0); auto x = torch::randn({5, 5}); std::string serialized; torch::save(x, [&](const void* buf, size_t n) { serialized.append(reinterpret_cast(buf), n); return n; }); torch::Tensor y; torch::load(y, serialized.data(), serialized.size()); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); torch::Tensor z; torch::load( z, [&](uint64_t pos, void* buf, size_t n) -> size_t { if (pos >= serialized.size()) return 0; size_t nbytes = std::min(static_cast(pos) + n, serialized.size()) - pos; memcpy(buf, serialized.data() + pos, nbytes); return nbytes; }, [&]() -> size_t { return serialized.size(); }); ASSERT_TRUE(z.defined()); ASSERT_EQ(x.sizes().vec(), z.sizes().vec()); 
ASSERT_TRUE(x.allclose(z)); } TEST(SerializeTest, Resized) { torch::manual_seed(0); auto x = torch::randn({11, 5}); x.resize_({5, 5}); auto y = save_and_load(x); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } TEST(SerializeTest, Sliced) { torch::manual_seed(0); auto x = torch::randn({11, 5}); x = x.slice(0, 1, 5); auto y = save_and_load(x); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } TEST(SerializeTest, NonContiguous) { torch::manual_seed(0); auto x = torch::randn({11, 5}); x = x.slice(1, 1, 4); auto y = save_and_load(x); ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } TEST(SerializeTest, ErrorOnMissingKey) { struct B : torch::nn::Module { B(const std::string& name_c) { register_buffer(name_c, torch::ones(5, torch::kFloat)); } }; struct A : torch::nn::Module { A(const std::string& name_b, const std::string& name_c) { register_module(name_b, std::make_shared(name_c)); } }; struct M : torch::nn::Module { M(const std::string& name_a, const std::string& name_b, const std::string& name_c) { register_module(name_a, std::make_shared(name_b, name_c)); } }; // create a hierarchy of models with names differing below the top level auto model1 = std::make_shared("a", "b", "c"); auto model2 = std::make_shared("a", "b", "x"); auto model3 = std::make_shared("a", "x", "c"); std::stringstream stream; torch::save(model1, stream); // We want the errors to contain hierarchy information, too. ASSERT_THROWS_WITH( torch::load(model2, stream), "No such serialized tensor 'a.b.x'"); stream.seekg(0, stream.beg); ASSERT_THROWS_WITH( torch::load(model3, stream), "No such serialized submodule: 'a.x'"); } TEST(SerializeTest, XOR) { // We better be able to save and load an XOR model! 
auto getLoss = [](Sequential model, uint32_t batch_size) { auto inputs = torch::empty({batch_size, 2}); auto labels = torch::empty({batch_size}); for (const auto i : c10::irange(batch_size)) { inputs[i] = torch::randint(2, {2}, torch::kInt64); labels[i] = inputs[i][0].item() ^ inputs[i][1].item(); } auto x = model->forward(inputs); return torch::binary_cross_entropy(x, labels); }; auto model = xor_model(); auto model2 = xor_model(); auto model3 = xor_model(); auto optimizer = torch::optim::SGD( model->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay( 1e-6)); float running_loss = 1; int epoch = 0; while (running_loss > 0.1) { torch::Tensor loss = getLoss(model, 4); optimizer.zero_grad(); loss.backward(); optimizer.step(); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) running_loss = running_loss * 0.99 + loss.sum().item() * 0.01; ASSERT_LT(epoch, 3000); epoch++; } auto tempfile = c10::make_tempfile(); torch::save(model, tempfile.name); torch::load(model2, tempfile.name); auto loss = getLoss(model2, 100); ASSERT_LT(loss.item(), 0.1); } TEST(SerializeTest, Optim) { auto model1 = Linear(5, 2); auto model2 = Linear(5, 2); auto model3 = Linear(5, 2); // Models 1, 2, 3 will have the same parameters. 
auto model_tempfile = c10::make_tempfile(); torch::save(model1, model_tempfile.name); torch::load(model2, model_tempfile.name); torch::load(model3, model_tempfile.name); auto param1 = model1->named_parameters(); auto param2 = model2->named_parameters(); auto param3 = model3->named_parameters(); for (const auto& p : param1) { ASSERT_TRUE(p->allclose(param2[p.key()])); ASSERT_TRUE(param2[p.key()].allclose(param3[p.key()])); } // Make some optimizers with momentum (and thus state) auto optim1 = torch::optim::SGD( model1->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9)); auto optim2 = torch::optim::SGD( model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9)); auto optim2_2 = torch::optim::SGD( model2->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9)); auto optim3 = torch::optim::SGD( model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9)); auto optim3_2 = torch::optim::SGD( model3->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); optimizer.step(); }; // Do 2 steps of model1 step(optim1, model1); step(optim1, model1); // Do 2 steps of model 2 without saving the optimizer step(optim2, model2); step(optim2_2, model2); // Do 2 steps of model 3 while saving the optimizer step(optim3, model3); auto optim_tempfile = c10::make_tempfile(); torch::save(optim3, optim_tempfile.name); torch::load(optim3_2, optim_tempfile.name); step(optim3_2, model3); param1 = model1->named_parameters(); param2 = model2->named_parameters(); param3 = model3->named_parameters(); for (const auto& p : param1) { const auto& name = p.key(); // Model 1 and 3 should be the same ASSERT_TRUE( param1[name].norm().item() == param3[name].norm().item()); ASSERT_TRUE( param1[name].norm().item() != param2[name].norm().item()); } } TEST(SerializeTest, Optim_Adagrad) { 
test_serialize_optimizer( AdagradOptions(1e-1)); // bc compatibility check auto model1 = Linear(5, 2); auto optim1 = torch::optim::Adagrad( model1->parameters(), torch::optim::AdagradOptions(1e-1)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); optimizer.step(); }; step(optim1, model1); auto optim1_2 = Adagrad(model1->parameters(), torch::optim::AdagradOptions(1e-1)); // fill up with optim1 sum_buffers std::vector sum_buffers; // fill up with optim1 state_buffers std::vector step_buffers; const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto& param : params_) { auto key_ = param.unsafeGetTensorImpl(); const AdagradParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); sum_buffers.emplace_back(curr_state_.sum()); step_buffers.emplace_back(curr_state_.step()); } // write sum_buffers and step_buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_tensors_to_archive(output_archive, "sum_buffers", sum_buffers); write_step_buffers(output_archive, "step_buffers", step_buffers); output_archive.save_to(optim_tempfile_old_format.name); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_SGD) { test_serialize_optimizer( SGDOptions(1e-1).momentum(0.9)); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); // added a tensor for lazy init check - when all params do not have a momentum // buffer entry model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::SGD( model1_params, torch::optim::SGDOptions(0.01).momentum(0.9)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear 
model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); optimizer.step(); }; step(optim1, model1); std::vector momentum_buffers; int64_t iteration_{0}; const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { if (i != (params_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); const SGDParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); momentum_buffers.emplace_back(curr_state_.momentum_buffer()); } } ASSERT_TRUE(momentum_buffers.size() == (params_.size() - 1)); // write momentum_buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_tensors_to_archive( output_archive, "momentum_buffers", momentum_buffers); write_int_value(output_archive, "iteration_", iteration_); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = SGD(model1_params, torch::optim::SGDOptions(1e-1).momentum(0.9)); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_Adam) { test_serialize_optimizer( AdamOptions().lr(0.99999).amsgrad(true).weight_decay(0.5)); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); // added a tensor for lazy init check - when all params do not have entry in // buffers model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::Adam( model1_params, torch::optim::AdamOptions().weight_decay(0.5)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); optimizer.step(); }; step(optim1, model1); std::vector step_buffers; std::vector exp_average_buffers; std::vector exp_average_sq_buffers; std::vector max_exp_average_sq_buffers; const 
auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { if (i != (params_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); const AdamParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); step_buffers.emplace_back(curr_state_.step()); exp_average_buffers.emplace_back(curr_state_.exp_avg()); exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq()); if (curr_state_.max_exp_avg_sq().defined()) { max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq()); } } } // write buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_step_buffers(output_archive, "step_buffers", step_buffers); write_tensors_to_archive( output_archive, "exp_average_buffers", exp_average_buffers); write_tensors_to_archive( output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); write_tensors_to_archive( output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = Adam(model1_params, torch::optim::AdamOptions()); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_AdamW) { test_serialize_optimizer( AdamWOptions().lr(0.99999).amsgrad(true).betas( std::make_tuple(0.999, 0.1))); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); // added a tensor for lazy init check - when all params do not have entry in // buffers model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::AdamW( model1_params, torch::optim::AdamWOptions().weight_decay(0.5)); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); 
optimizer.step(); }; step(optim1, model1); std::vector step_buffers; std::vector exp_average_buffers; std::vector exp_average_sq_buffers; std::vector max_exp_average_sq_buffers; const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { if (i != (params_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); const AdamWParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); step_buffers.emplace_back(curr_state_.step()); exp_average_buffers.emplace_back(curr_state_.exp_avg()); exp_average_sq_buffers.emplace_back(curr_state_.exp_avg_sq()); if (curr_state_.max_exp_avg_sq().defined()) { max_exp_average_sq_buffers.emplace_back(curr_state_.max_exp_avg_sq()); } } } // write buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_step_buffers(output_archive, "step_buffers", step_buffers); write_tensors_to_archive( output_archive, "exp_average_buffers", exp_average_buffers); write_tensors_to_archive( output_archive, "exp_average_sq_buffers", exp_average_sq_buffers); write_tensors_to_archive( output_archive, "max_exp_average_sq_buffers", max_exp_average_sq_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = AdamW(model1_params, torch::optim::AdamWOptions()); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_RMSprop) { auto options = RMSpropOptions(0.1).momentum(0.9).centered(true); test_serialize_optimizer(options); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); // added a tensor for lazy init check - when all params do not have a momentum // buffer entry model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::RMSprop(model1_params, options); auto x = 
torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); optimizer.step(); }; step(optim1, model1); std::vector square_average_buffers; std::vector momentum_buffers; std::vector grad_average_buffers; const auto& params_ = optim1.param_groups()[0].params(); const auto& optim1_state = optim1.state(); for (const auto i : c10::irange(params_.size())) { if (i != (params_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); square_average_buffers.emplace_back(curr_state_.square_avg()); if (curr_state_.momentum_buffer().defined()) { momentum_buffers.emplace_back(curr_state_.momentum_buffer()); } if (curr_state_.grad_avg().defined()) { grad_average_buffers.emplace_back(curr_state_.grad_avg()); } } } // write buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; write_tensors_to_archive( output_archive, "square_average_buffers", square_average_buffers); write_tensors_to_archive( output_archive, "momentum_buffers", momentum_buffers); write_tensors_to_archive( output_archive, "grad_average_buffers", grad_average_buffers); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = RMSprop(model1_params, options); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); const auto& params1_2_ = optim1_2.param_groups()[0].params(); auto& optim1_2_state = optim1_2.state(); // old RMSprop didn't track step value for (const auto i : c10::irange(params1_2_.size())) { if (i != (params1_2_.size() - 1)) { auto key_ = params_[i].unsafeGetTensorImpl(); const RMSpropParamState& curr_state_ = static_cast(*(optim1_state.at(key_).get())); RMSpropParamState& curr_state1_2_ = static_cast(*(optim1_2_state.at(key_).get())); curr_state1_2_.step(curr_state_.step()); } } 
is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, Optim_LBFGS) { test_serialize_optimizer( LBFGSOptions(), true); // bc compatibility check auto model1 = Linear(5, 2); auto model1_params = model1->parameters(); // added a tensor for lazy init check - when all params do not have entry in // buffers model1_params.emplace_back(torch::randn({2, 3})); auto optim1 = torch::optim::LBFGS(model1_params, torch::optim::LBFGSOptions()); auto x = torch::ones({10, 5}); auto step = [&x](torch::optim::Optimizer& optimizer, Linear model) { optimizer.zero_grad(); auto y = model->forward(x).sum(); y.backward(); auto closure = []() { return torch::tensor({10}); }; optimizer.step(closure); }; step(optim1, model1); at::Tensor d, t, H_diag, prev_flat_grad, prev_loss; std::deque old_dirs, old_stps; const auto& params_ = optim1.param_groups()[0].params(); auto key_ = params_[0].unsafeGetTensorImpl(); const auto& optim1_state = static_cast(*(optim1.state().at(key_).get())); d = optim1_state.d(); t = at::tensor(optim1_state.t()); H_diag = optim1_state.H_diag(); prev_flat_grad = optim1_state.prev_flat_grad(); prev_loss = at::tensor(optim1_state.prev_loss()); old_dirs = optim1_state.old_dirs(); // write buffers to the file auto optim_tempfile_old_format = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; output_archive.write("d", d, /*is_buffer=*/true); output_archive.write("t", t, /*is_buffer=*/true); output_archive.write("H_diag", H_diag, /*is_buffer=*/true); output_archive.write("prev_flat_grad", prev_flat_grad, /*is_buffer=*/true); output_archive.write("prev_loss", prev_loss, /*is_buffer=*/true); write_tensors_to_archive(output_archive, "old_dirs", old_dirs); write_tensors_to_archive(output_archive, "old_stps", old_stps); output_archive.save_to(optim_tempfile_old_format.name); auto optim1_2 = LBFGS(model1_params, torch::optim::LBFGSOptions()); OLD_SERIALIZATION_LOGIC_WARNING_CHECK( torch::load, optim1_2, optim_tempfile_old_format.name); 
const auto& params1_2_ = optim1_2.param_groups()[0].params(); auto param_key = params1_2_[0].unsafeGetTensorImpl(); auto& optim1_2_state = static_cast(*(optim1_2.state().at(param_key).get())); // old LBFGS didn't track func_evals, n_iter, ro, al values optim1_2_state.func_evals(optim1_state.func_evals()); optim1_2_state.n_iter(optim1_state.n_iter()); optim1_2_state.ro(optim1_state.ro()); optim1_2_state.al(optim1_state.al()); is_optimizer_state_equal(optim1.state(), optim1_2.state()); } TEST(SerializeTest, XOR_CUDA) { torch::manual_seed(0); // We better be able to save and load a XOR model! auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda = false) { auto inputs = torch::empty({batch_size, 2}); auto labels = torch::empty({batch_size}); if (is_cuda) { inputs = inputs.cuda(); labels = labels.cuda(); } for (const auto i : c10::irange(batch_size)) { inputs[i] = torch::randint(2, {2}, torch::kInt64); labels[i] = inputs[i][0].item() ^ inputs[i][1].item(); } auto x = model->forward(inputs); return torch::binary_cross_entropy(x, labels); }; auto model = xor_model(); auto model2 = xor_model(); auto model3 = xor_model(); auto optimizer = torch::optim::SGD( model->parameters(), torch::optim::SGDOptions(1e-1).momentum(0.9).nesterov(true).weight_decay( 1e-6)); float running_loss = 1; int epoch = 0; while (running_loss > 0.1) { torch::Tensor loss = getLoss(model, 4); optimizer.zero_grad(); loss.backward(); optimizer.step(); // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) running_loss = running_loss * 0.99 + loss.sum().item() * 0.01; ASSERT_LT(epoch, 3000); epoch++; } auto tempfile = c10::make_tempfile(); torch::save(model, tempfile.name); torch::load(model2, tempfile.name); auto loss = getLoss(model2, 100); ASSERT_LT(loss.item(), 0.1); model2->to(torch::kCUDA); loss = getLoss(model2, 100, true); ASSERT_LT(loss.item(), 0.1); auto tempfile2 = c10::make_tempfile(); torch::save(model2, 
tempfile2.name); torch::load(model3, tempfile2.name); loss = getLoss(model3, 100, true); ASSERT_LT(loss.item(), 0.1); } TEST( SerializeTest, CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) { struct C : torch::nn::Module { C() { register_buffer("foo", torch::ones(5, torch::kInt32)); } }; struct B : torch::nn::Module {}; struct A : torch::nn::Module { A() { register_module("b", std::make_shared()); register_module("c", std::make_shared()); } }; struct M : torch::nn::Module { M() { register_module("a", std::make_shared()); } }; auto out = std::make_shared(); std::stringstream ss; torch::save(out, ss); auto in = std::make_shared(); torch::load(in, ss); const int output = in->named_buffers()["a.c.foo"].sum().item(); ASSERT_EQ(output, 5); } TEST(SerializeTest, VectorOfTensors) { torch::manual_seed(0); std::vector x_vec = { torch::randn({1, 2}), torch::randn({3, 4})}; std::stringstream stream; torch::save(x_vec, stream); std::vector y_vec; torch::load(y_vec, stream); for (const auto i : c10::irange(x_vec.size())) { auto& x = x_vec[i]; auto& y = y_vec[i]; ASSERT_TRUE(y.defined()); ASSERT_EQ(x.sizes().vec(), y.sizes().vec()); ASSERT_TRUE(x.allclose(y)); } } TEST(SerializeTest, IValue) { c10::IValue ivalue(1); auto tempfile = c10::make_tempfile(); torch::serialize::OutputArchive output_archive; output_archive.write("value", ivalue); output_archive.save_to(tempfile.name); torch::serialize::InputArchive input_archive; input_archive.load_from(tempfile.name); c10::IValue ivalue_out; input_archive.read("value", ivalue_out); ASSERT_EQ(ivalue_out.toInt(), 1); ASSERT_THROWS_WITH( input_archive.read("bad_key", ivalue_out), "does not have a field with name"); } // NOTE: if a `Module` contains unserializable submodules (e.g. // `nn::Functional`), we expect those submodules to be skipped when the `Module` // is being serialized. 
TEST(SerializeTest, UnserializableSubmoduleIsSkippedWhenSavingModule) {
  struct A : torch::nn::Module {
    A() {
      register_module("relu", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  std::stringstream ss;
  torch::save(out, ss);

  torch::serialize::InputArchive archive;
  archive.load_from(ss);
  torch::serialize::InputArchive relu_archive;

  // Submodule with name "relu" should not exist in the `InputArchive`,
  // because the "relu" submodule is an `nn::Functional` and is not
  // serializable.
  ASSERT_FALSE(archive.try_read("relu", relu_archive));
}

// NOTE: If a `Module` contains unserializable submodules (e.g.
// `nn::Functional`), we don't check the existence of those submodules in the
// `InputArchive` when deserializing.
TEST(SerializeTest, UnserializableSubmoduleIsIgnoredWhenLoadingModule) {
  struct B : torch::nn::Module {
    B() {
      register_module("relu1", torch::nn::Functional(torch::relu));
      register_buffer("foo", torch::zeros(5, torch::kInt32));
    }
  };
  struct A : torch::nn::Module {
    A() {
      register_module("b", std::make_shared<B>());
      register_module("relu2", torch::nn::Functional(torch::relu));
    }
  };

  auto out = std::make_shared<A>();
  // Manually change the values of "b.foo", so that we can check whether the
  // buffer contains these values after deserialization.
  out->named_buffers()["b.foo"].fill_(1);
  auto tempfile = c10::make_tempfile();
  torch::save(out, tempfile.name);

  torch::serialize::InputArchive archive;
  archive.load_from(tempfile.name);
  torch::serialize::InputArchive archive_b;
  torch::serialize::InputArchive archive_relu;
  torch::Tensor tensor_foo;

  ASSERT_TRUE(archive.try_read("b", archive_b));
  ASSERT_TRUE(archive_b.try_read("foo", tensor_foo, /*is_buffer=*/true));

  // Submodule with name "relu1" should not exist in `archive_b`, because the
  // "relu1" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive_b.try_read("relu1", archive_relu));

  // Submodule with name "relu2" should not exist in `archive`, because the
  // "relu2" submodule is an `nn::Functional` and is not serializable.
  ASSERT_FALSE(archive.try_read("relu2", archive_relu));

  auto in = std::make_shared<A>();
  // `torch::load(...)` works without error, even though `A` contains the
  // `nn::Functional` submodules while the serialized file doesn't, because the
  // `nn::Functional` submodules are not serializable and thus ignored when
  // deserializing.
  torch::load(in, tempfile.name);

  // Check that the "b.foo" buffer is correctly deserialized from the file.
  const int output = in->named_buffers()["b.foo"].sum().item<int>();
  // `output` should equal to the sum of the values we manually assigned to
  // "b.foo" before serialization.
  ASSERT_EQ(output, 5);
}