xref: /aosp_15_r20/external/pytorch/test/cpp/api/any.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <algorithm>
8*da0073e9SAndroid Build Coastguard Worker #include <string>
9*da0073e9SAndroid Build Coastguard Worker 
10*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker struct AnyModuleTest : torch::test::SeedingFixture {};
13*da0073e9SAndroid Build Coastguard Worker 
// AnyModule can invoke a zero-argument forward() and hand back its int result.
TEST_F(AnyModuleTest, SimpleReturnType) {
  struct M : torch::nn::Module {
    int forward() {
      const int answer = 123;
      return answer;
    }
  };
  AnyModule wrapped(M{});
  const int result = wrapped.forward<int>();
  ASSERT_EQ(result, 123);
}
23*da0073e9SAndroid Build Coastguard Worker 
// A single int argument travels through AnyModule and is echoed back.
TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  struct M : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule wrapped(M{});
  const int echoed = wrapped.forward<int>(5);
  ASSERT_EQ(echoed, 5);
}
33*da0073e9SAndroid Build Coastguard Worker 
// `const char*` works as both the argument and the return type of forward().
TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  struct M : torch::nn::Module {
    const char* forward(const char* text) {
      return text;
    }
  };
  AnyModule wrapped(M{});
  const char* returned = wrapped.forward<const char*>("hello");
  // Wrap in std::string so the contents are compared, not the pointer values.
  ASSERT_EQ(std::string(returned), "hello");
}
43*da0073e9SAndroid Build Coastguard Worker 
// forward() takes an int plus a `const double`, truncates their sum to int and
// returns it as a string (4 + 3.14 -> "7").
TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  struct M : torch::nn::Module {
    std::string forward(int whole, const double fraction) {
      const int truncated = static_cast<int>(whole + fraction);
      return std::to_string(truncated);
    }
  };
  AnyModule wrapped(M{});
  int base = 4;
  const std::string expected("7");
  ASSERT_EQ(wrapped.forward<std::string>(base, 3.14), expected);
}
54*da0073e9SAndroid Build Coastguard Worker 
// forward() mixes by-value, const-reference and rvalue-reference std::string
// parameters and returns a ones-tensor with one element per character of the
// concatenated input.
TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  struct M : torch::nn::Module {
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto s = a + b + c;
      return torch::ones({static_cast<int64_t>(s.size())});
    }
  };
  AnyModule any(M{});
  // "a" + "ab" + "abc" has 6 characters, so the tensor sums to 6.
  // ASSERT_EQ (instead of ASSERT_TRUE(x == 6)) prints the actual value when
  // the check fails.
  ASSERT_EQ(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>(),
      6);
}
73*da0073e9SAndroid Build Coastguard Worker 
// Passing a double to a forward() that takes float is rejected with an error
// naming both the expected and the received type.
TEST_F(AnyModuleTest, WrongArgumentType) {
  struct M : torch::nn::Module {
    int forward(float x) {
      return static_cast<int>(x);
    }
  };
  AnyModule any(M{});
  const std::string expected_error =
      "Expected argument #0 to be of type float, "
      "but received value of type double";
  ASSERT_THROWS_WITH(any.forward(5.0), expected_error);
}
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker struct M_test_wrong_number_of_arguments : torch::nn::Module {
forwardM_test_wrong_number_of_arguments89*da0073e9SAndroid Build Coastguard Worker   int forward(int a, int b) {
90*da0073e9SAndroid Build Coastguard Worker     return a + b;
91*da0073e9SAndroid Build Coastguard Worker   }
92*da0073e9SAndroid Build Coastguard Worker };
93*da0073e9SAndroid Build Coastguard Worker 
// Calling forward() with too few or too many arguments throws. When too few
// are given, the message additionally points at FORWARD_HAS_DEFAULT_ARGS.
TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule any(M_test_wrong_number_of_arguments{});
  // Under MSVC the module name in the error message carries a "struct " prefix.
#if defined(_MSC_VER)
  const std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  const std::string module_name = "M_test_wrong_number_of_arguments";
#endif
  // Trailing hint that accompanies the "too few arguments" error.
  const std::string default_args_hint = "If " + module_name +
      "'s forward() method has default arguments, "
      "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

  ASSERT_THROWS_WITH(
      any.forward(),
      module_name +
          "'s forward() method expects 2 argument(s), but received 0. " +
          default_args_hint);
  ASSERT_THROWS_WITH(
      any.forward(5),
      module_name +
          "'s forward() method expects 2 argument(s), but received 1. " +
          default_args_hint);
  ASSERT_THROWS_WITH(
      any.forward(1, 2, 3),
      module_name +
          "'s forward() method expects 2 argument(s), but received 3.");
}
122*da0073e9SAndroid Build Coastguard Worker 
123*da0073e9SAndroid Build Coastguard Worker struct M_default_arg_with_macro : torch::nn::Module {
forwardM_default_arg_with_macro124*da0073e9SAndroid Build Coastguard Worker   double forward(int a, int b = 2, double c = 3.0) {
125*da0073e9SAndroid Build Coastguard Worker     return a + b + c;
126*da0073e9SAndroid Build Coastguard Worker   }
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker  protected:
129*da0073e9SAndroid Build Coastguard Worker   FORWARD_HAS_DEFAULT_ARGS(
130*da0073e9SAndroid Build Coastguard Worker       {1, torch::nn::AnyValue(2)},
131*da0073e9SAndroid Build Coastguard Worker       {2, torch::nn::AnyValue(3.0)})
132*da0073e9SAndroid Build Coastguard Worker };
133*da0073e9SAndroid Build Coastguard Worker 
134*da0073e9SAndroid Build Coastguard Worker struct M_default_arg_without_macro : torch::nn::Module {
forwardM_default_arg_without_macro135*da0073e9SAndroid Build Coastguard Worker   double forward(int a, int b = 2, double c = 3.0) {
136*da0073e9SAndroid Build Coastguard Worker     return a + b + c;
137*da0073e9SAndroid Build Coastguard Worker   }
138*da0073e9SAndroid Build Coastguard Worker };
139*da0073e9SAndroid Build Coastguard Worker 
TEST_F(AnyModuleTest,PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod)140*da0073e9SAndroid Build Coastguard Worker TEST_F(
141*da0073e9SAndroid Build Coastguard Worker     AnyModuleTest,
142*da0073e9SAndroid Build Coastguard Worker     PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
143*da0073e9SAndroid Build Coastguard Worker   {
144*da0073e9SAndroid Build Coastguard Worker     AnyModule any(M_default_arg_with_macro{});
145*da0073e9SAndroid Build Coastguard Worker 
146*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(any.forward<double>(1), 6.0);
147*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(any.forward<double>(1, 3), 7.0);
148*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
149*da0073e9SAndroid Build Coastguard Worker 
150*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
151*da0073e9SAndroid Build Coastguard Worker         any.forward(),
152*da0073e9SAndroid Build Coastguard Worker         "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
153*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
154*da0073e9SAndroid Build Coastguard Worker         any.forward(1, 2, 3.0, 4),
155*da0073e9SAndroid Build Coastguard Worker         "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
156*da0073e9SAndroid Build Coastguard Worker   }
157*da0073e9SAndroid Build Coastguard Worker   {
158*da0073e9SAndroid Build Coastguard Worker     AnyModule any(M_default_arg_without_macro{});
159*da0073e9SAndroid Build Coastguard Worker 
160*da0073e9SAndroid Build Coastguard Worker     ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
161*da0073e9SAndroid Build Coastguard Worker 
162*da0073e9SAndroid Build Coastguard Worker #if defined(_MSC_VER)
163*da0073e9SAndroid Build Coastguard Worker     std::string module_name = "struct M_default_arg_without_macro";
164*da0073e9SAndroid Build Coastguard Worker #else
165*da0073e9SAndroid Build Coastguard Worker     std::string module_name = "M_default_arg_without_macro";
166*da0073e9SAndroid Build Coastguard Worker #endif
167*da0073e9SAndroid Build Coastguard Worker 
168*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
169*da0073e9SAndroid Build Coastguard Worker         any.forward(),
170*da0073e9SAndroid Build Coastguard Worker         module_name +
171*da0073e9SAndroid Build Coastguard Worker             "'s forward() method expects 3 argument(s), but received 0. "
172*da0073e9SAndroid Build Coastguard Worker             "If " +
173*da0073e9SAndroid Build Coastguard Worker             module_name +
174*da0073e9SAndroid Build Coastguard Worker             "'s forward() method has default arguments, "
175*da0073e9SAndroid Build Coastguard Worker             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
176*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
177*da0073e9SAndroid Build Coastguard Worker         any.forward<double>(1),
178*da0073e9SAndroid Build Coastguard Worker         module_name +
179*da0073e9SAndroid Build Coastguard Worker             "'s forward() method expects 3 argument(s), but received 1. "
180*da0073e9SAndroid Build Coastguard Worker             "If " +
181*da0073e9SAndroid Build Coastguard Worker             module_name +
182*da0073e9SAndroid Build Coastguard Worker             "'s forward() method has default arguments, "
183*da0073e9SAndroid Build Coastguard Worker             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
184*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
185*da0073e9SAndroid Build Coastguard Worker         any.forward<double>(1, 3),
186*da0073e9SAndroid Build Coastguard Worker         module_name +
187*da0073e9SAndroid Build Coastguard Worker             "'s forward() method expects 3 argument(s), but received 2. "
188*da0073e9SAndroid Build Coastguard Worker             "If " +
189*da0073e9SAndroid Build Coastguard Worker             module_name +
190*da0073e9SAndroid Build Coastguard Worker             "'s forward() method has default arguments, "
191*da0073e9SAndroid Build Coastguard Worker             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
192*da0073e9SAndroid Build Coastguard Worker     ASSERT_THROWS_WITH(
193*da0073e9SAndroid Build Coastguard Worker         any.forward(1, 2, 3.0, 4),
194*da0073e9SAndroid Build Coastguard Worker         module_name +
195*da0073e9SAndroid Build Coastguard Worker             "'s forward() method expects 3 argument(s), but received 4.");
196*da0073e9SAndroid Build Coastguard Worker   }
197*da0073e9SAndroid Build Coastguard Worker }
198*da0073e9SAndroid Build Coastguard Worker 
199*da0073e9SAndroid Build Coastguard Worker struct M : torch::nn::Module {
MM200*da0073e9SAndroid Build Coastguard Worker   explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
201*da0073e9SAndroid Build Coastguard Worker   int value;
forwardM202*da0073e9SAndroid Build Coastguard Worker   int forward(float x) {
203*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
204*da0073e9SAndroid Build Coastguard Worker     return x;
205*da0073e9SAndroid Build Coastguard Worker   }
206*da0073e9SAndroid Build Coastguard Worker };
207*da0073e9SAndroid Build Coastguard Worker 
// get<M>() exposes the stored module when M is its actual type.
TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  AnyModule any(M{5});
  const M& stored = any.get<M>();
  ASSERT_EQ(stored.value, 5);
}
212*da0073e9SAndroid Build Coastguard Worker 
// get<N>() throws when the stored module is an M, not an N.
TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor t) {
      return t;
    }
  };
  AnyModule any(M{5});
  const char* const expected_error = "Attempted to cast module";
  ASSERT_THROWS_WITH(any.get<N>(), expected_error);
}
222*da0073e9SAndroid Build Coastguard Worker 
// The untyped ptr() returns a non-null module pointer whose name() is the one
// given to the Module constructor ("M").
TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  AnyModule any(M{5});
  const auto base = any.ptr();
  ASSERT_NE(base, nullptr);
  ASSERT_EQ(base->name(), "M");
}
229*da0073e9SAndroid Build Coastguard Worker 
// ptr<M>() with the stored module's real type yields a usable typed pointer.
TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  AnyModule any(M{5});
  const auto typed = any.ptr<M>();
  ASSERT_NE(typed, nullptr);
  ASSERT_EQ(typed->value, 5);
}
236*da0073e9SAndroid Build Coastguard Worker 
// ptr<N>() throws when the stored module is an M, not an N.
TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor t) {
      return t;
    }
  };
  AnyModule any(M{5});
  const char* const expected_error = "Attempted to cast module";
  ASSERT_THROWS_WITH(any.ptr<N>(), expected_error);
}
246*da0073e9SAndroid Build Coastguard Worker 
// A default-constructed AnyModule reports empty until a module is assigned.
TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  struct M : torch::nn::Module {
    explicit M(int initial_value) : value(initial_value) {}
    int value;
    int forward(float x) {
      return static_cast<int>(x);
    }
  };
  AnyModule any;
  ASSERT_TRUE(any.is_empty());

  any = std::make_shared<M>(5);
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.get<M>().value, 5);
}
262*da0073e9SAndroid Build Coastguard Worker 
// On an empty AnyModule, every accessor throws an error that names the method
// that was called.
TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  struct M : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule empty;
  ASSERT_TRUE(empty.is_empty());

  ASSERT_THROWS_WITH(empty.get<M>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.type_info(), "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.forward<int>(5), "Cannot call forward() on an empty AnyModule");
}
279*da0073e9SAndroid Build Coastguard Worker 
// One AnyModule can be reassigned to modules whose forward() signatures differ.
// After each assignment it dispatches to the newly stored module.
TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  struct Stringify : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  struct AddThree : torch::nn::Module {
    int forward(float x) {
      return static_cast<int>(3 + x);
    }
  };

  AnyModule any;
  ASSERT_TRUE(any.is_empty());

  any = std::make_shared<Stringify>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<std::string>(5), "5");

  any = std::make_shared<AddThree>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<int>(5.0f), 8);
}
301*da0073e9SAndroid Build Coastguard Worker 
// AnyModule accepts a ModuleHolder (a custom holder and the built-in Linear).
// It can hand the module back either as the holder type or as the contained
// Impl type.
TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int initial_value)
        : torch::nn::Module("M"), value(initial_value) {}
    int value;
    int forward(float x) {
      return static_cast<int>(x);
    }
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  AnyModule any(M{5});
  ASSERT_EQ(any.get<MImpl>().value, 5);
  ASSERT_EQ(any.get<M>()->value, 5);

  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  Linear linear(module.get<Linear>());
  // Check the results, not just that these conversions compile.
  ASSERT_NE(ptr, nullptr);
  // The holder returned by get<Linear>() wraps the same instance as ptr().
  ASSERT_EQ(ptr.get(), static_cast<Module*>(linear.ptr().get()));
  ASSERT_EQ(linear->options.in_features(), 3);
  ASSERT_EQ(linear->options.out_features(), 4);
}
325*da0073e9SAndroid Build Coastguard Worker 
// Both autograd::Variable and plain at::Tensor inputs reach a forward() that
// takes torch::Tensor.
TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  AnyModule any(M{});
  // ASSERT_EQ (matching the check below) reports the actual sum on failure,
  // unlike ASSERT_TRUE(x == 5).
  ASSERT_EQ(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>(),
      5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
344*da0073e9SAndroid Build Coastguard Worker 
345*da0073e9SAndroid Build Coastguard Worker namespace torch {
346*da0073e9SAndroid Build Coastguard Worker namespace nn {
347*da0073e9SAndroid Build Coastguard Worker struct TestAnyValue {
348*da0073e9SAndroid Build Coastguard Worker   template <typename T>
349*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
TestAnyValuetorch::nn::TestAnyValue350*da0073e9SAndroid Build Coastguard Worker   explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
operator ()torch::nn::TestAnyValue351*da0073e9SAndroid Build Coastguard Worker   AnyValue operator()() {
352*da0073e9SAndroid Build Coastguard Worker     return std::move(value_);
353*da0073e9SAndroid Build Coastguard Worker   }
354*da0073e9SAndroid Build Coastguard Worker   AnyValue value_;
355*da0073e9SAndroid Build Coastguard Worker };
356*da0073e9SAndroid Build Coastguard Worker template <typename T>
make_value(T && value)357*da0073e9SAndroid Build Coastguard Worker AnyValue make_value(T&& value) {
358*da0073e9SAndroid Build Coastguard Worker   return TestAnyValue(std::forward<T>(value))();
359*da0073e9SAndroid Build Coastguard Worker }
360*da0073e9SAndroid Build Coastguard Worker } // namespace nn
361*da0073e9SAndroid Build Coastguard Worker } // namespace torch
362*da0073e9SAndroid Build Coastguard Worker 
363*da0073e9SAndroid Build Coastguard Worker struct AnyValueTest : torch::test::SeedingFixture {};
364*da0073e9SAndroid Build Coastguard Worker 
// An AnyValue built from an int can be read back as int.
TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  auto value = make_value<int>(5);
  const auto held = value.try_get<int>();
  ASSERT_NE(held, nullptr);
  // const and non-const types share a typeid(), but casting Holder<int> to
  // Holder<const int> is undefined behavior according to UBSAN
  // (https://github.com/pytorch/pytorch/issues/26964), so this check stays
  // disabled:
  // ASSERT_NE(value.try_get<const int>(), nullptr);
  ASSERT_EQ(value.get<int>(), 5);
}
375*da0073e9SAndroid Build Coastguard Worker // This test does not work at all, because it looks like make_value
376*da0073e9SAndroid Build Coastguard Worker // decays const int into int.
377*da0073e9SAndroid Build Coastguard Worker // TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
378*da0073e9SAndroid Build Coastguard Worker //  auto value = make_value<const int>(5);
379*da0073e9SAndroid Build Coastguard Worker //  ASSERT_NE(value.try_get<const int>(), nullptr);
380*da0073e9SAndroid Build Coastguard Worker //  // ASSERT_NE(value.try_get<int>(), nullptr);
381*da0073e9SAndroid Build Coastguard Worker //  ASSERT_EQ(value.get<const int>(), 5);
382*da0073e9SAndroid Build Coastguard Worker //}
// A string literal is retrievable as `const char*`.
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  auto value = make_value("hello");
  ASSERT_NE(value.try_get<const char*>(), nullptr);
  // Copy into std::string so the characters are compared, not the pointer.
  const std::string contents(value.get<const char*>());
  ASSERT_EQ(contents, "hello");
}
// A std::string is retrievable as std::string.
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  const std::string original("hello");
  auto value = make_value(std::string(original));
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  ASSERT_EQ(value.get<std::string>(), "hello");
}
// A raw pointer is stored as-is and still points at the original object.
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  std::string text("hello");
  std::string* text_ptr = &text;
  auto value = make_value(text_ptr);
  ASSERT_NE(value.try_get<std::string*>(), nullptr);
  std::string* retrieved = value.get<std::string*>();
  ASSERT_EQ(*retrieved, "hello");
}
// A value passed through a const reference is retrievable as plain
// std::string.
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  std::string text("hello");
  const std::string& text_ref = text;
  auto value = make_value(text_ref);
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  ASSERT_EQ(value.get<std::string>(), "hello");
}
407*da0073e9SAndroid Build Coastguard Worker 
// try_get<T>() only finds the exact held type. A held int is not reachable
// as float, long, or std::string.
TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  auto held = make_value(5);

  // Sanity check: the matching type is found.
  int* as_int = held.try_get<int>();
  ASSERT_NE(as_int, nullptr);

  // Each mismatched type yields nullptr, even types int converts to.
  ASSERT_EQ(held.try_get<float>(), nullptr);
  ASSERT_EQ(held.try_get<long>(), nullptr);
  ASSERT_EQ(held.try_get<std::string>(), nullptr);
}
415*da0073e9SAndroid Build Coastguard Worker 
// get<T>() with a mismatched T throws. The message names both the requested
// type and the type actually held.
TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  auto held = make_value(5);
  int* as_int = held.try_get<int>();
  ASSERT_NE(as_int, nullptr);

  ASSERT_THROWS_WITH(
      held.get<float>(),
      "Attempted to cast AnyValue to float, but its actual type is int");

  ASSERT_THROWS_WITH(
      held.get<long>(),
      "Attempted to cast AnyValue to long, but its actual type is int");
}
428*da0073e9SAndroid Build Coastguard Worker 
// A value built from std::move of another AnyValue still holds the int 5.
TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  auto source = make_value(5);
  auto destination = make_value(std::move(source));
  int* as_int = destination.try_get<int>();
  ASSERT_NE(as_int, nullptr);
  ASSERT_EQ(destination.get<int>(), 5);
}
435*da0073e9SAndroid Build Coastguard Worker 
// Move-assigning one AnyValue over another replaces the target's int (10)
// with the source's int (5).
TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  auto source = make_value(5);
  auto target = make_value(10);
  target = std::move(source);
  int* as_int = target.try_get<int>();
  ASSERT_NE(as_int, nullptr);
  ASSERT_EQ(target.get<int>(), 5);
}
443*da0073e9SAndroid Build Coastguard Worker 
// The held type of an int is reported as exactly `int`. Compare the
// std::type_info objects themselves: equal hash_code() values are not
// guaranteed to mean equal types.
TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  auto value = make_value(5);
  ASSERT_TRUE(value.type_info() == typeid(int))
      << "actual type: " << value.type_info().name();
}
448*da0073e9SAndroid Build Coastguard Worker 
// A string literal is held as `const char*`, not as a char array. Compare the
// std::type_info objects themselves: equal hash_code() values are not
// guaranteed to mean equal types.
TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  auto value = make_value("hello");
  ASSERT_TRUE(value.type_info() == typeid(const char*))
      << "actual type: " << value.type_info().name();
}
453*da0073e9SAndroid Build Coastguard Worker 
// A std::string is reported as exactly `std::string`. Compare the
// std::type_info objects themselves: equal hash_code() values are not
// guaranteed to mean equal types.
TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  auto value = make_value(std::string("hello"));
  ASSERT_TRUE(value.type_info() == typeid(std::string))
      << "actual type: " << value.type_info().name();
}
458