1 #include <gtest/gtest.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/csrc/utils/variadic.h>
5 #include <torch/detail/static.h>
6 #include <torch/torch.h>
7
8 #include <string>
9 #include <type_traits>
10 #include <vector>
11
12 template <
13 typename T,
14 typename = std::enable_if_t<!torch::detail::is_module<T>::value>>
f(T && m)15 bool f(T&& m) {
16 return false;
17 }
18
19 template <typename T>
f(T && m)20 torch::detail::enable_if_module_t<T, bool> f(T&& m) {
21 return true;
22 }
23
TEST(TestStatic, EnableIfModule) {
  namespace td = torch::detail;

  // Overload resolution on `f`: a module temporary should pick the `true`
  // overload, while a plain int should pick the `false` one.
  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  ASSERT_FALSE(f(5));

  // check_not_lvalue_references<Ts...>() is expected to hold for value types
  // and to fail once the list contains an lvalue reference such as `int&` or
  // `std::string&`. Extra parentheses shield template-argument commas from
  // the assertion macros.
  ASSERT_TRUE(td::check_not_lvalue_references<int>());
  ASSERT_TRUE((td::check_not_lvalue_references<float, int, char>()));
  ASSERT_FALSE((td::check_not_lvalue_references<float, int&, char>()));
  ASSERT_TRUE(td::check_not_lvalue_references<std::string>());
  ASSERT_FALSE(td::check_not_lvalue_references<std::string&>());
}
34
35 namespace {
36
37 struct A : torch::nn::Module {
forward__anon24ea8ac20111::A38 int forward() {
39 return 5;
40 }
41 };
42
43 struct B : torch::nn::Module {
forward__anon24ea8ac20111::B44 std::string forward(torch::Tensor tensor) {
45 return "";
46 }
47 };
48
49 struct C : torch::nn::Module {
forward__anon24ea8ac20111::C50 float forward(torch::Tensor& tensor) {
51 return 5.0;
52 }
53 };
54
55 struct D : torch::nn::Module {
forward__anon24ea8ac20111::D56 char forward(torch::Tensor&& tensor) {
57 return 'x';
58 }
59 };
60
61 struct E : torch::nn::Module {};
62
63 } // anonymous namespace
64
65 // Put in a function because macros don't handle the comma between arguments to
66 // is_same well ...
67 template <typename Module, typename ExpectedType, typename... Args>
assert_has_expected_type()68 void assert_has_expected_type() {
69 using ReturnType =
70 typename torch::detail::return_type_of_forward<Module, Args...>::type;
71 constexpr bool is_expected_type =
72 std::is_same<ReturnType, ExpectedType>::value;
73 ASSERT_TRUE(is_expected_type) << Module().name();
74 }
75
TEST(TestStatic, ReturnTypeOfForward) {
  using torch::Tensor;

  // forward() with no arguments.
  assert_has_expected_type<A, int>();

  // Tensor argument passed by value, by lvalue reference and by rvalue
  // reference respectively.
  assert_has_expected_type<B, std::string, Tensor>();
  assert_has_expected_type<C, float, Tensor&>();
  assert_has_expected_type<D, char, Tensor&&>();

  // A module that declares no forward() is expected to report void.
  assert_has_expected_type<E, void>();
}
83
TEST(TestStatic, Apply) {
  std::vector<int> v;
  // The test expects torch::apply to call the lambda once per trailing
  // argument, in argument order.
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);

  // Comparing whole vectors checks both the size and the ordered contents.
  // It keeps both sides `int`, avoiding the size_t-vs-int comparisons that
  // trip -Wsign-compare inside gtest's EQ helper, and prints both vectors on
  // failure.
  const std::vector<int> expected{1, 2, 3, 4, 5};
  ASSERT_EQ(v, expected);
}
92