// xref: /aosp_15_r20/external/pytorch/test/cpp/api/static.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/utils/variadic.h>
5 #include <torch/detail/static.h>
6 #include <torch/torch.h>
7 
8 #include <string>
9 #include <type_traits>
10 #include <vector>
11 
12 template <
13     typename T,
14     typename = std::enable_if_t<!torch::detail::is_module<T>::value>>
f(T && m)15 bool f(T&& m) {
16   return false;
17 }
18 
19 template <typename T>
f(T && m)20 torch::detail::enable_if_module_t<T, bool> f(T&& m) {
21   return true;
22 }
23 
// Exercises the module-detection helpers from torch/detail/static.h:
// overload selection of f() above, and check_not_lvalue_references<Ts...>(),
// which (as asserted below) is true for non-reference types and false when
// one of the listed types is an lvalue reference.
TEST(TestStatic, EnableIfModule) {
  namespace td = torch::detail;

  // A LinearImpl temporary picks the module overload; an int picks the
  // fallback.
  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  ASSERT_FALSE(f(5));

  // The extra parentheses stop the preprocessor from splitting the template
  // argument lists at their commas.
  ASSERT_TRUE(td::check_not_lvalue_references<int>());
  ASSERT_TRUE((td::check_not_lvalue_references<float, int, char>()));
  ASSERT_FALSE((td::check_not_lvalue_references<float, int&, char>()));
  ASSERT_TRUE(td::check_not_lvalue_references<std::string>());
  ASSERT_FALSE(td::check_not_lvalue_references<std::string&>());
}
34 
35 namespace {
36 
37 struct A : torch::nn::Module {
forward__anon24ea8ac20111::A38   int forward() {
39     return 5;
40   }
41 };
42 
43 struct B : torch::nn::Module {
forward__anon24ea8ac20111::B44   std::string forward(torch::Tensor tensor) {
45     return "";
46   }
47 };
48 
49 struct C : torch::nn::Module {
forward__anon24ea8ac20111::C50   float forward(torch::Tensor& tensor) {
51     return 5.0;
52   }
53 };
54 
55 struct D : torch::nn::Module {
forward__anon24ea8ac20111::D56   char forward(torch::Tensor&& tensor) {
57     return 'x';
58   }
59 };
60 
61 struct E : torch::nn::Module {};
62 
63 } // anonymous namespace
64 
65 // Put in a function because macros don't handle the comma between arguments to
66 // is_same well ...
67 template <typename Module, typename ExpectedType, typename... Args>
assert_has_expected_type()68 void assert_has_expected_type() {
69   using ReturnType =
70       typename torch::detail::return_type_of_forward<Module, Args...>::type;
71   constexpr bool is_expected_type =
72       std::is_same<ReturnType, ExpectedType>::value;
73   ASSERT_TRUE(is_expected_type) << Module().name();
74 }
75 
// Checks return_type_of_forward against each fixture module: no arguments,
// tensor by value, by lvalue reference, by rvalue reference, and a module
// that declares no forward() (expected to resolve to void).
TEST(TestStatic, ReturnTypeOfForward) {
  // int A::forward()
  assert_has_expected_type<A, int>();
  // std::string B::forward(torch::Tensor)
  assert_has_expected_type<B, std::string, torch::Tensor>();
  // float C::forward(torch::Tensor&)
  assert_has_expected_type<C, float, torch::Tensor&>();
  // char D::forward(torch::Tensor&&)
  assert_has_expected_type<D, char, torch::Tensor&&>();
  // E has no forward()
  assert_has_expected_type<E, void>();
}
83 
// Checks that torch::apply invokes the callable once per trailing argument,
// in argument order: applying a push_back lambda to 1..5 must yield {1,...,5}.
TEST(TestStatic, Apply) {
  std::vector<int> v;
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
  // Keep both sides of each comparison the same signedness: gtest compares
  // inside a template, so mixing size_t and int triggers -Wsign-compare.
  ASSERT_EQ(v.size(), 5u);
  for (const auto i : c10::irange(v.size())) {
    ASSERT_EQ(v.at(i), static_cast<int>(i) + 1);
  }
}
92