// GoogleTest-style test suite for torch::nn::AnyModule and AnyValue.
//
// Fixtures (both derive from torch::test::SeedingFixture):
//   * AnyModuleTest - wraps small locally defined torch::nn::Module structs in
//     an AnyModule and checks:
//       - forward() with int, string literal, std::string and torch::Tensor
//         return/argument types;
//       - the exact error text produced for a wrong argument type and for a
//         wrong number of arguments (the expected module name is prefixed
//         with "struct " when _MSC_VER is defined);
//       - forward() methods with default arguments, with and without the
//         FORWARD_HAS_DEFAULT_ARGS macro;
//       - get() / ptr() access, including the "Attempted to cast module"
//         failure path;
//       - the empty state (is_empty(), and the "Cannot call ... on an empty
//         AnyModule" errors);
//       - assigning different modules to one AnyModule, construction from a
//         ModuleHolder / Linear, and passing an autograd::Variable or
//         at::Tensor to forward().
//   * AnyValueTest - builds AnyValue objects through the helper
//     TestAnyValue / make_value declared below inside namespace torch::nn and
//     checks try_get(), get(), type_info(), move construction and move
//     assignment.
//
// NOTE(review): this text looks like it lost its original line breaks and
// every angle-bracketed token (the #include targets, and template arguments
// such as in `static_cast(x + f)`, `std::make_shared(5)`, `template //`).
// As laid out here each `//` comment also runs to the end of its physical
// line. Presumably an extraction artifact - TODO confirm against the pristine
// file before trying to build or edit the code below.
#include #include #include #include #include using namespace torch::nn; struct AnyModuleTest : torch::test::SeedingFixture {}; TEST_F(AnyModuleTest, SimpleReturnType) { struct M : torch::nn::Module { int forward() { return 123; } }; AnyModule any(M{}); ASSERT_EQ(any.forward(), 123); } TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) { struct M : torch::nn::Module { int forward(int x) { return x; } }; AnyModule any(M{}); ASSERT_EQ(any.forward(5), 5); } TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) { struct M : torch::nn::Module { const char* forward(const char* x) { return x; } }; AnyModule any(M{}); ASSERT_EQ(any.forward("hello"), std::string("hello")); } TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) { struct M : torch::nn::Module { std::string forward(int x, const double f) { return std::to_string(static_cast(x + f)); } }; AnyModule any(M{}); int x = 4; ASSERT_EQ(any.forward(x, 3.14), std::string("7")); } TEST_F( AnyModuleTest, TensorReturnTypeAndStringArgumentsWithFunkyQualifications) { struct M : torch::nn::Module { torch::Tensor forward( std::string a, const std::string& b, std::string&& c) { const auto s = a + b + c; return torch::ones({static_cast(s.size())}); } }; AnyModule any(M{}); ASSERT_TRUE( any.forward(std::string("a"), std::string("ab"), std::string("abc")) .sum() .item() == 6); } TEST_F(AnyModuleTest, WrongArgumentType) { struct M : torch::nn::Module { int forward(float x) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return x; } }; AnyModule any(M{}); ASSERT_THROWS_WITH( any.forward(5.0), "Expected argument #0 to be of type float, " "but received value of type double"); } struct M_test_wrong_number_of_arguments : torch::nn::Module { int forward(int a, int b) { return a + b; } }; TEST_F(AnyModuleTest, WrongNumberOfArguments) { AnyModule any(M_test_wrong_number_of_arguments{}); #if defined(_MSC_VER) std::string module_name = "struct M_test_wrong_number_of_arguments"; #else 
std::string module_name = "M_test_wrong_number_of_arguments"; #endif ASSERT_THROWS_WITH( any.forward(), module_name + "'s forward() method expects 2 argument(s), but received 0. " "If " + module_name + "'s forward() method has default arguments, " "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(5), module_name + "'s forward() method expects 2 argument(s), but received 1. " "If " + module_name + "'s forward() method has default arguments, " "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 2, 3), module_name + "'s forward() method expects 2 argument(s), but received 3."); } struct M_default_arg_with_macro : torch::nn::Module { double forward(int a, int b = 2, double c = 3.0) { return a + b + c; } protected: FORWARD_HAS_DEFAULT_ARGS( {1, torch::nn::AnyValue(2)}, {2, torch::nn::AnyValue(3.0)}) }; struct M_default_arg_without_macro : torch::nn::Module { double forward(int a, int b = 2, double c = 3.0) { return a + b + c; } }; TEST_F( AnyModuleTest, PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) { { AnyModule any(M_default_arg_with_macro{}); ASSERT_EQ(any.forward(1), 6.0); ASSERT_EQ(any.forward(1, 3), 7.0); ASSERT_EQ(any.forward(1, 3, 5.0), 9.0); ASSERT_THROWS_WITH( any.forward(), "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0."); ASSERT_THROWS_WITH( any.forward(1, 2, 3.0, 4), "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4."); } { AnyModule any(M_default_arg_without_macro{}); ASSERT_EQ(any.forward(1, 3, 5.0), 9.0); #if defined(_MSC_VER) std::string module_name = "struct M_default_arg_without_macro"; #else std::string module_name = "M_default_arg_without_macro"; #endif ASSERT_THROWS_WITH( any.forward(), module_name + "'s forward() 
method expects 3 argument(s), but received 0. " "If " + module_name + "'s forward() method has default arguments, " "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1), module_name + "'s forward() method expects 3 argument(s), but received 1. " "If " + module_name + "'s forward() method has default arguments, " "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 3), module_name + "'s forward() method expects 3 argument(s), but received 2. " "If " + module_name + "'s forward() method has default arguments, " "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro."); ASSERT_THROWS_WITH( any.forward(1, 2, 3.0, 4), module_name + "'s forward() method expects 3 argument(s), but received 4."); } } struct M : torch::nn::Module { explicit M(int value_) : torch::nn::Module("M"), value(value_) {} int value; int forward(float x) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return x; } }; TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) { AnyModule any(M{5}); ASSERT_EQ(any.get().value, 5); } TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) { struct N : torch::nn::Module { torch::Tensor forward(torch::Tensor input) { return input; } }; AnyModule any(M{5}); ASSERT_THROWS_WITH(any.get(), "Attempted to cast module"); } TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) { AnyModule any(M{5}); auto ptr = any.ptr(); ASSERT_NE(ptr, nullptr); ASSERT_EQ(ptr->name(), "M"); } TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) { AnyModule any(M{5}); auto ptr = any.ptr(); ASSERT_NE(ptr, nullptr); ASSERT_EQ(ptr->value, 5); } TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) { struct N : torch::nn::Module { torch::Tensor forward(torch::Tensor input) { return input; } }; AnyModule any(M{5}); ASSERT_THROWS_WITH(any.ptr(), 
"Attempted to cast module"); } TEST_F(AnyModuleTest, DefaultStateIsEmpty) { struct M : torch::nn::Module { explicit M(int value_) : value(value_) {} int value; int forward(float x) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return x; } }; AnyModule any; ASSERT_TRUE(any.is_empty()); any = std::make_shared(5); ASSERT_FALSE(any.is_empty()); ASSERT_EQ(any.get().value, 5); } TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) { struct M : torch::nn::Module { int forward(int x) { return x; } }; AnyModule any; ASSERT_TRUE(any.is_empty()); ASSERT_THROWS_WITH(any.get(), "Cannot call get() on an empty AnyModule"); ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule"); ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule"); ASSERT_THROWS_WITH( any.type_info(), "Cannot call type_info() on an empty AnyModule"); ASSERT_THROWS_WITH( any.forward(5), "Cannot call forward() on an empty AnyModule"); } TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) { struct M : torch::nn::Module { std::string forward(int x) { return std::to_string(x); } }; struct N : torch::nn::Module { int forward(float x) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return 3 + x; } }; AnyModule any; ASSERT_TRUE(any.is_empty()); any = std::make_shared(); ASSERT_FALSE(any.is_empty()); ASSERT_EQ(any.forward(5), "5"); any = std::make_shared(); ASSERT_FALSE(any.is_empty()); ASSERT_EQ(any.forward(5.0f), 8); } TEST_F(AnyModuleTest, ConstructsFromModuleHolder) { struct MImpl : torch::nn::Module { explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {} int value; int forward(float x) { // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) return x; } }; struct M : torch::nn::ModuleHolder { using torch::nn::ModuleHolder::ModuleHolder; using torch::nn::ModuleHolder::get; }; AnyModule any(M{5}); ASSERT_EQ(any.get().value, 5); 
ASSERT_EQ(any.get()->value, 5); AnyModule module(Linear(3, 4)); std::shared_ptr ptr = module.ptr(); Linear linear(module.get()); } TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) { struct M : torch::nn::Module { torch::Tensor forward(torch::Tensor input) { return input; } }; // When you have an autograd::Variable, it should be converted to a // torch::Tensor before being passed to the function (to avoid a type // mismatch). AnyModule any(M{}); ASSERT_TRUE( any.forward(torch::autograd::Variable(torch::ones(5))) .sum() .item() == 5); // at::Tensors that are not variables work too. ASSERT_EQ(any.forward(at::ones(5)).sum().item(), 5); } namespace torch { namespace nn { struct TestAnyValue { template // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) explicit TestAnyValue(T&& value) : value_(std::forward(value)) {} AnyValue operator()() { return std::move(value_); } AnyValue value_; }; template AnyValue make_value(T&& value) { return TestAnyValue(std::forward(value))(); } } // namespace nn } // namespace torch struct AnyValueTest : torch::test::SeedingFixture {}; TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) { auto value = make_value(5); ASSERT_NE(value.try_get(), nullptr); // const and non-const types have the same typeid(), // but casting Holder to Holder is undefined // behavior according to UBSAN: // https://github.com/pytorch/pytorch/issues/26964 // ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.get(), 5); } // This test does not work at all, because it looks like make_value // decays const int into int. 
// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) { // auto value = make_value(5); // ASSERT_NE(value.try_get(), nullptr); // // ASSERT_NE(value.try_get(), nullptr); // ASSERT_EQ(value.get(), 5); //} TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) { auto value = make_value("hello"); ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.get(), std::string("hello")); } TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) { auto value = make_value(std::string("hello")); ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.get(), "hello"); } TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) { std::string s("hello"); std::string* p = &s; auto value = make_value(p); ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(*value.get(), "hello"); } TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) { std::string s("hello"); const std::string& t = s; auto value = make_value(t); ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.get(), "hello"); } TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) { auto value = make_value(5); ASSERT_NE(value.try_get(), nullptr); ASSERT_EQ(value.try_get(), nullptr); ASSERT_EQ(value.try_get(), nullptr); ASSERT_EQ(value.try_get(), nullptr); } TEST_F(AnyValueTest, GetThrowsForTheWrongType) { auto value = make_value(5); ASSERT_NE(value.try_get(), nullptr); ASSERT_THROWS_WITH( value.get(), "Attempted to cast AnyValue to float, " "but its actual type is int"); ASSERT_THROWS_WITH( value.get(), "Attempted to cast AnyValue to long, " "but its actual type is int"); } TEST_F(AnyValueTest, MoveConstructionIsAllowed) { auto value = make_value(5); auto copy = make_value(std::move(value)); ASSERT_NE(copy.try_get(), nullptr); ASSERT_EQ(copy.get(), 5); } TEST_F(AnyValueTest, MoveAssignmentIsAllowed) { auto value = make_value(5); auto copy = make_value(10); copy = std::move(value); ASSERT_NE(copy.try_get(), nullptr); ASSERT_EQ(copy.get(), 5); } TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) 
{ auto value = make_value(5); ASSERT_EQ(value.type_info().hash_code(), typeid(int).hash_code()); } TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) { auto value = make_value("hello"); ASSERT_EQ(value.type_info().hash_code(), typeid(const char*).hash_code()); } TEST_F(AnyValueTest, TypeInfoIsCorrectForString) { auto value = make_value(std::string("hello")); ASSERT_EQ(value.type_info().hash_code(), typeid(std::string).hash_code()); }