#include #include #include #include #include #include using namespace torch::nn; using namespace torch::test; struct ParameterDictTest : torch::test::SeedingFixture {}; TEST_F(ParameterDictTest, ConstructFromTensor) { ParameterDict dict; torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); ASSERT_TRUE(ta.requires_grad()); ASSERT_FALSE(tb.requires_grad()); dict->insert("A", ta); dict->insert("B", tb); dict->insert("C", tc); ASSERT_EQ(dict->size(), 3); ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item()); ASSERT_TRUE(dict["A"].requires_grad()); ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item()); ASSERT_FALSE(dict["B"].requires_grad()); } TEST_F(ParameterDictTest, ConstructFromOrderedDict) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::OrderedDict params = { {"A", ta}, {"B", tb}, {"C", tc}}; auto dict = torch::nn::ParameterDict(params); ASSERT_EQ(dict->size(), 3); ASSERT_TRUE(torch::all(torch::eq(dict["A"], ta)).item()); ASSERT_TRUE(dict["A"].requires_grad()); ASSERT_TRUE(torch::all(torch::eq(dict["B"], tb)).item()); ASSERT_FALSE(dict["B"].requires_grad()); } TEST_F(ParameterDictTest, InsertAndContains) { ParameterDict dict; dict->insert("A", torch::tensor({1.0})); ASSERT_EQ(dict->size(), 1); ASSERT_TRUE(dict->contains("A")); ASSERT_FALSE(dict->contains("C")); } TEST_F(ParameterDictTest, InsertAndClear) { ParameterDict dict; dict->insert("A", torch::tensor({1.0})); ASSERT_EQ(dict->size(), 1); dict->clear(); ASSERT_EQ(dict->size(), 0); } TEST_F(ParameterDictTest, InsertAndPop) { ParameterDict dict; dict->insert("A", torch::tensor({1.0})); ASSERT_EQ(dict->size(), 1); ASSERT_THROWS_WITH(dict->pop("B"), "Parameter 'B' is not defined"); torch::Tensor p = dict->pop("A"); ASSERT_EQ(dict->size(), 0); ASSERT_TRUE(torch::eq(p, torch::tensor({1.0})).item()); } TEST_F(ParameterDictTest, SimpleUpdate) { ParameterDict dict; ParameterDict wrongDict; ParameterDict rightDict; dict->insert("A", torch::tensor({1.0})); dict->insert("B", torch::tensor({2.0})); dict->insert("C", torch::tensor({3.0})); wrongDict->insert("A", torch::tensor({5.0})); wrongDict->insert("D", torch::tensor({5.0})); ASSERT_THROWS_WITH(dict->update(*wrongDict), "Parameter 'D' is not defined"); rightDict->insert("A", torch::tensor({5.0})); dict->update(*rightDict); ASSERT_EQ(dict->size(), 3); ASSERT_TRUE(torch::eq(dict["A"], torch::tensor({5.0})).item()); } TEST_F(ParameterDictTest, Keys) { torch::OrderedDict params = { {"a", torch::tensor({1.0})}, {"b", torch::tensor({2.0})}, {"c", torch::tensor({1.0, 2.0})}}; auto dict = torch::nn::ParameterDict(params); std::vector keys = dict->keys(); std::vector true_keys{"a", "b", "c"}; ASSERT_EQ(keys, true_keys); } TEST_F(ParameterDictTest, Values) { torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); torch::OrderedDict params = { {"a", ta}, {"b", tb}, {"c", tc}}; auto dict = torch::nn::ParameterDict(params); std::vector values = dict->values(); std::vector true_values{ta, tb, tc}; for (auto i = 0U; i < values.size(); i += 1) { ASSERT_TRUE(torch::all(torch::eq(values[i], true_values[i])).item()); } } TEST_F(ParameterDictTest, Get) { ParameterDict dict; torch::Tensor ta = torch::randn({1, 2}, torch::requires_grad(true)); torch::Tensor tb = torch::randn({1, 2}, torch::requires_grad(false)); torch::Tensor tc = torch::randn({1, 2}); ASSERT_TRUE(ta.requires_grad()); ASSERT_FALSE(tb.requires_grad()); dict->insert("A", ta); dict->insert("B", tb); dict->insert("C", tc); ASSERT_EQ(dict->size(), 3); ASSERT_TRUE(torch::all(torch::eq(dict->get("A"), ta)).item()); ASSERT_TRUE(dict->get("A").requires_grad()); ASSERT_TRUE(torch::all(torch::eq(dict->get("B"), tb)).item()); ASSERT_FALSE(dict->get("B").requires_grad()); } TEST_F(ParameterDictTest, PrettyPrintParameterDict) { torch::OrderedDict params = { {"a", torch::tensor({1.0})}, {"b", torch::tensor({2.0, 1.0})}, {"c", torch::tensor({{3.0}, {2.1}})}, {"d", torch::tensor({{3.0, 1.3}, {1.2, 2.1}})}}; auto dict = torch::nn::ParameterDict(params); ASSERT_EQ( c10::str(dict), "torch::nn::ParameterDict(\n" "(a): Parameter containing: [Float of size [1]]\n" "(b): Parameter containing: [Float of size [2]]\n" "(c): Parameter containing: [Float of size [2, 1]]\n" "(d): Parameter containing: [Float of size [2, 2]]\n" ")"); }