1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/variadic.h>
5*da0073e9SAndroid Build Coastguard Worker #include <torch/detail/static.h>
6*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #include <string>
9*da0073e9SAndroid Build Coastguard Worker #include <type_traits>
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker template <
13*da0073e9SAndroid Build Coastguard Worker typename T,
14*da0073e9SAndroid Build Coastguard Worker typename = std::enable_if_t<!torch::detail::is_module<T>::value>>
f(T && m)15*da0073e9SAndroid Build Coastguard Worker bool f(T&& m) {
16*da0073e9SAndroid Build Coastguard Worker return false;
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker template <typename T>
f(T && m)20*da0073e9SAndroid Build Coastguard Worker torch::detail::enable_if_module_t<T, bool> f(T&& m) {
21*da0073e9SAndroid Build Coastguard Worker return true;
22*da0073e9SAndroid Build Coastguard Worker }
23*da0073e9SAndroid Build Coastguard Worker
TEST(TestStatic, EnableIfModule) {
  // Overload selection for f(): a module instance must pick the
  // enable_if_module_t overload (true), a plain int the fallback (false).
  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  ASSERT_FALSE(f(5));

  // check_not_lvalue_references<Ts...>() is expected to be true exactly when
  // none of Ts is an lvalue reference. The extra parentheses stop the commas
  // inside multi-argument template lists from splitting the macro arguments.
  ASSERT_TRUE(torch::detail::check_not_lvalue_references<int>());
  ASSERT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>()));
  ASSERT_FALSE(
      (torch::detail::check_not_lvalue_references<float, int&, char>()));
  ASSERT_TRUE(torch::detail::check_not_lvalue_references<std::string>());
  ASSERT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
}
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker namespace {
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker struct A : torch::nn::Module {
forward__anon24ea8ac20111::A38*da0073e9SAndroid Build Coastguard Worker int forward() {
39*da0073e9SAndroid Build Coastguard Worker return 5;
40*da0073e9SAndroid Build Coastguard Worker }
41*da0073e9SAndroid Build Coastguard Worker };
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker struct B : torch::nn::Module {
forward__anon24ea8ac20111::B44*da0073e9SAndroid Build Coastguard Worker std::string forward(torch::Tensor tensor) {
45*da0073e9SAndroid Build Coastguard Worker return "";
46*da0073e9SAndroid Build Coastguard Worker }
47*da0073e9SAndroid Build Coastguard Worker };
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker struct C : torch::nn::Module {
forward__anon24ea8ac20111::C50*da0073e9SAndroid Build Coastguard Worker float forward(torch::Tensor& tensor) {
51*da0073e9SAndroid Build Coastguard Worker return 5.0;
52*da0073e9SAndroid Build Coastguard Worker }
53*da0073e9SAndroid Build Coastguard Worker };
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker struct D : torch::nn::Module {
forward__anon24ea8ac20111::D56*da0073e9SAndroid Build Coastguard Worker char forward(torch::Tensor&& tensor) {
57*da0073e9SAndroid Build Coastguard Worker return 'x';
58*da0073e9SAndroid Build Coastguard Worker }
59*da0073e9SAndroid Build Coastguard Worker };
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker struct E : torch::nn::Module {};
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker // Put in a function because macros don't handle the comma between arguments to
66*da0073e9SAndroid Build Coastguard Worker // is_same well ...
67*da0073e9SAndroid Build Coastguard Worker template <typename Module, typename ExpectedType, typename... Args>
assert_has_expected_type()68*da0073e9SAndroid Build Coastguard Worker void assert_has_expected_type() {
69*da0073e9SAndroid Build Coastguard Worker using ReturnType =
70*da0073e9SAndroid Build Coastguard Worker typename torch::detail::return_type_of_forward<Module, Args...>::type;
71*da0073e9SAndroid Build Coastguard Worker constexpr bool is_expected_type =
72*da0073e9SAndroid Build Coastguard Worker std::is_same<ReturnType, ExpectedType>::value;
73*da0073e9SAndroid Build Coastguard Worker ASSERT_TRUE(is_expected_type) << Module().name();
74*da0073e9SAndroid Build Coastguard Worker }
75*da0073e9SAndroid Build Coastguard Worker
TEST(TestStatic, ReturnTypeOfForward) {
  using torch::Tensor;

  // Nullary forward().
  assert_has_expected_type<A, int>();
  // forward() taking its argument by value, lvalue ref and rvalue ref; the
  // Args list mirrors each parameter type exactly.
  assert_has_expected_type<B, std::string, Tensor>();
  assert_has_expected_type<C, float, Tensor&>();
  assert_has_expected_type<D, char, Tensor&&>();
  // A module with no forward() is expected to report void.
  assert_has_expected_type<E, void>();
}
83*da0073e9SAndroid Build Coastguard Worker
TEST(TestStatic, Apply) {
  std::vector<int> v;
  // The test expects torch::apply to invoke the callable once per trailing
  // argument, in argument order.
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
  // Operands of matching signedness: comparing size_t with int inside gtest's
  // ASSERT_EQ helper triggers -Wsign-compare.
  ASSERT_EQ(v.size(), 5u);
  for (const auto i : c10::irange(v.size())) {
    // i is size_t here; convert it so both sides of the comparison are int.
    ASSERT_EQ(v.at(i), static_cast<int>(i) + 1);
  }
}
92