#include #include #include #include #include #include #include #include template < typename T, typename = std::enable_if_t::value>> bool f(T&& m) { return false; } template torch::detail::enable_if_module_t f(T&& m) { return true; } TEST(TestStatic, EnableIfModule) { ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2))); ASSERT_FALSE(f(5)); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_TRUE((torch::detail::check_not_lvalue_references())); ASSERT_FALSE( (torch::detail::check_not_lvalue_references())); ASSERT_TRUE(torch::detail::check_not_lvalue_references()); ASSERT_FALSE(torch::detail::check_not_lvalue_references()); } namespace { struct A : torch::nn::Module { int forward() { return 5; } }; struct B : torch::nn::Module { std::string forward(torch::Tensor tensor) { return ""; } }; struct C : torch::nn::Module { float forward(torch::Tensor& tensor) { return 5.0; } }; struct D : torch::nn::Module { char forward(torch::Tensor&& tensor) { return 'x'; } }; struct E : torch::nn::Module {}; } // anonymous namespace // Put in a function because macros don't handle the comma between arguments to // is_same well ... template void assert_has_expected_type() { using ReturnType = typename torch::detail::return_type_of_forward::type; constexpr bool is_expected_type = std::is_same::value; ASSERT_TRUE(is_expected_type) << Module().name(); } TEST(TestStatic, ReturnTypeOfForward) { assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); assert_has_expected_type(); } TEST(TestStatic, Apply) { std::vector v; torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5); ASSERT_EQ(v.size(), 5); for (const auto i : c10::irange(v.size())) { ASSERT_EQ(v.at(i), i + 1); } }