1 #include <gtest/gtest.h>
2
3 #include <torch/torch.h>
4
5 #include <test/cpp/api/support.h>
6
7 #include <algorithm>
8 #include <string>
9
10 using namespace torch::nn;
11
12 struct AnyModuleTest : torch::test::SeedingFixture {};
13
TEST_F(AnyModuleTest, SimpleReturnType) {
  // A module whose forward() takes no inputs and returns a plain int.
  struct ConstantModule : torch::nn::Module {
    int forward() {
      return 123;
    }
  };
  AnyModule wrapped(ConstantModule{});
  const int result = wrapped.forward<int>();
  ASSERT_EQ(result, 123);
}
23
TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  // forward() echoes its single int argument back unchanged.
  struct Identity : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule wrapped(Identity{});
  const int echoed = wrapped.forward<int>(5);
  ASSERT_EQ(echoed, 5);
}
33
TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  // forward() passes a C string straight through.
  struct Passthrough : torch::nn::Module {
    const char* forward(const char* text) {
      return text;
    }
  };
  AnyModule wrapped(Passthrough{});
  const char* const returned = wrapped.forward<const char*>("hello");
  // Compare the characters, not the pointer values.
  ASSERT_EQ(returned, std::string("hello"));
}
43
TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  // forward() takes an int and a const-qualified double and returns their
  // sum, truncated to int, as a string.
  struct Summer : torch::nn::Module {
    std::string forward(int lhs, const double rhs) {
      const int truncated = static_cast<int>(lhs + rhs);
      return std::to_string(truncated);
    }
  };
  AnyModule wrapped(Summer{});
  int lhs = 4;
  const std::string expected("7");
  ASSERT_EQ(wrapped.forward<std::string>(lhs, 3.14), expected);
}
54
TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  // forward() mixes by-value, const-lvalue-reference and rvalue-reference
  // std::string parameters and returns a tensor of ones whose length is the
  // total number of characters received.
  struct M : torch::nn::Module {
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto s = a + b + c;
      return torch::ones({static_cast<int64_t>(s.size())});
    }
  };
  AnyModule any(M{});
  // "a" + "ab" + "abc" has 6 characters, so the result sums to 6.
  // ASSERT_EQ (rather than ASSERT_TRUE(x == 6)) reports both values on
  // failure.
  ASSERT_EQ(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>(),
      6);
}
73
TEST_F(AnyModuleTest, WrongArgumentType) {
  struct M : torch::nn::Module {
    int forward(float x) {
      // Explicit cast states the intended float -> int truncation instead of
      // relying on an implicit narrowing conversion silenced by NOLINT.
      return static_cast<int>(x);
    }
  };
  AnyModule any(M{});
  // 5.0 is a double literal. The expected error shows AnyModule rejects it
  // for the float parameter rather than converting it.
  ASSERT_THROWS_WITH(
      any.forward(5.0),
      "Expected argument #0 to be of type float, "
      "but received value of type double");
}
87
88 struct M_test_wrong_number_of_arguments : torch::nn::Module {
forwardM_test_wrong_number_of_arguments89 int forward(int a, int b) {
90 return a + b;
91 }
92 };
93
TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule wrapped(M_test_wrong_number_of_arguments{});

  // Under MSVC the reported type name carries a leading "struct ".
#if defined(_MSC_VER)
  const std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  const std::string module_name = "M_test_wrong_number_of_arguments";
#endif

  // Expected text when too few arguments are passed; it also points the
  // user at the FORWARD_HAS_DEFAULT_ARGS macro.
  const auto too_few_message = [&module_name](const std::string& received) {
    return module_name +
        "'s forward() method expects 2 argument(s), but received " + received +
        ". If " + module_name +
        "'s forward() method has default arguments, "
        "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
  };

  ASSERT_THROWS_WITH(wrapped.forward(), too_few_message("0"));
  ASSERT_THROWS_WITH(wrapped.forward(5), too_few_message("1"));
  // Too many arguments: only the count-mismatch text is checked here.
  ASSERT_THROWS_WITH(
      wrapped.forward(1, 2, 3),
      module_name +
          "'s forward() method expects 2 argument(s), but received 3.");
}
122
123 struct M_default_arg_with_macro : torch::nn::Module {
forwardM_default_arg_with_macro124 double forward(int a, int b = 2, double c = 3.0) {
125 return a + b + c;
126 }
127
128 protected:
129 FORWARD_HAS_DEFAULT_ARGS(
130 {1, torch::nn::AnyValue(2)},
131 {2, torch::nn::AnyValue(3.0)})
132 };
133
134 struct M_default_arg_without_macro : torch::nn::Module {
forwardM_default_arg_without_macro135 double forward(int a, int b = 2, double c = 3.0) {
136 return a + b + c;
137 }
138 };
139
TEST_F(
    AnyModuleTest,
    PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
  // With FORWARD_HAS_DEFAULT_ARGS, omitted trailing arguments take the
  // registered defaults, and between 1 and 3 arguments are accepted.
  {
    AnyModule with_macro(M_default_arg_with_macro{});

    ASSERT_EQ(with_macro.forward<double>(1), 6.0);
    ASSERT_EQ(with_macro.forward<double>(1, 3), 7.0);
    ASSERT_EQ(with_macro.forward<double>(1, 3, 5.0), 9.0);

    ASSERT_THROWS_WITH(
        with_macro.forward(),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
    ASSERT_THROWS_WITH(
        with_macro.forward(1, 2, 3.0, 4),
        "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
  }
  // Without the macro, exactly three arguments are required, and too-few
  // errors suggest FORWARD_HAS_DEFAULT_ARGS.
  {
    AnyModule without_macro(M_default_arg_without_macro{});

    ASSERT_EQ(without_macro.forward<double>(1, 3, 5.0), 9.0);

    // Under MSVC the reported type name carries a leading "struct ".
#if defined(_MSC_VER)
    const std::string module_name = "struct M_default_arg_without_macro";
#else
    const std::string module_name = "M_default_arg_without_macro";
#endif

    const auto too_few_message = [&module_name](const std::string& received) {
      return module_name +
          "'s forward() method expects 3 argument(s), but received " +
          received + ". If " + module_name +
          "'s forward() method has default arguments, "
          "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
    };

    ASSERT_THROWS_WITH(without_macro.forward(), too_few_message("0"));
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1), too_few_message("1"));
    ASSERT_THROWS_WITH(
        without_macro.forward<double>(1, 3), too_few_message("2"));
    ASSERT_THROWS_WITH(
        without_macro.forward(1, 2, 3.0, 4),
        module_name +
            "'s forward() method expects 3 argument(s), but received 4.");
  }
}
198
199 struct M : torch::nn::Module {
MM200 explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
201 int value;
forwardM202 int forward(float x) {
203 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
204 return x;
205 }
206 };
207
TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  // get<T>() with the stored concrete type gives access to the module.
  AnyModule wrapped(M{5});
  const M& module = wrapped.get<M>();
  ASSERT_EQ(module.value, 5);
}
212
TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  // A module type different from the one held by the AnyModule.
  struct Other : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  AnyModule wrapped(M{5});
  ASSERT_THROWS_WITH(wrapped.get<Other>(), "Attempted to cast module");
}
222
TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  // The untyped ptr() returns a base-class pointer that still reports the
  // name given to the Module constructor.
  AnyModule wrapped(M{5});
  const auto base = wrapped.ptr();
  ASSERT_NE(base, nullptr);
  ASSERT_EQ(base->name(), "M");
}
229
TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  // ptr<T>() with the correct concrete type yields a usable typed pointer.
  AnyModule wrapped(M{5});
  const auto typed = wrapped.ptr<M>();
  ASSERT_NE(typed, nullptr);
  ASSERT_EQ(typed->value, 5);
}
236
TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  // A module type different from the one held by the AnyModule.
  struct Other : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };
  AnyModule wrapped(M{5});
  ASSERT_THROWS_WITH(wrapped.ptr<Other>(), "Attempted to cast module");
}
246
TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  // Local M (shadows the file-scope M); it does not pass a module name.
  struct M : torch::nn::Module {
    explicit M(int value_) : value(value_) {}
    int value;
    int forward(float x) {
      // Explicit cast states the intended float -> int truncation instead of
      // relying on an implicit narrowing conversion silenced by NOLINT.
      return static_cast<int>(x);
    }
  };
  // A default-constructed AnyModule holds no module...
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  // ...until a module is assigned to it.
  any = std::make_shared<M>(5);
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.get<M>().value, 5);
}
262
TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  // Used only as the template argument for get<T>() / ptr<T>().
  struct Identity : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule empty;
  ASSERT_TRUE(empty.is_empty());
  // Every accessor must refuse to operate on an AnyModule with no module.
  ASSERT_THROWS_WITH(
      empty.get<Identity>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.ptr<Identity>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(empty.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.type_info(), "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      empty.forward<int>(5), "Cannot call forward() on an empty AnyModule");
}
279
TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  struct M : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  struct N : torch::nn::Module {
    int forward(float x) {
      // Explicit cast: the float sum is truncated to int on purpose, instead
      // of an implicit narrowing conversion silenced by NOLINT.
      return static_cast<int>(3 + x);
    }
  };
  AnyModule any;
  ASSERT_TRUE(any.is_empty());
  // The same AnyModule can be re-pointed at modules whose forward()
  // signatures differ completely.
  any = std::make_shared<M>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<std::string>(5), "5");
  any = std::make_shared<N>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<int>(5.0f), 8);
}
301
TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
    int value;
    int forward(float x) {
      // Explicit cast states the intended float -> int truncation instead of
      // relying on an implicit narrowing conversion silenced by NOLINT.
      return static_cast<int>(x);
    }
  };

  // Minimal ModuleHolder around MImpl.
  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  // An AnyModule built from a holder can be read back either as the
  // contained implementation type or as the holder type.
  AnyModule any(M{5});
  ASSERT_EQ(any.get<MImpl>().value, 5);
  ASSERT_EQ(any.get<M>()->value, 5);

  // The same applies to a built-in holder such as Linear.
  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  ASSERT_NE(ptr, nullptr);
  Linear linear(module.get<Linear>());
  // The holder returned by get<Linear>() must wrap the very module stored in
  // the AnyModule, with the feature sizes it was constructed with.
  ASSERT_EQ(linear.ptr().get(), ptr.get());
  ASSERT_EQ(linear->options.in_features(), 3);
  ASSERT_EQ(linear->options.out_features(), 4);
}
325
TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  AnyModule any(M{});
  // ASSERT_EQ (consistent with the assertion below) reports both operands
  // on failure, unlike ASSERT_TRUE(a == b).
  ASSERT_EQ(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>(),
      5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
344
namespace torch {
namespace nn {
// Test-only helper for building AnyValue instances. It is declared inside
// torch::nn under this exact name, presumably because AnyValue grants it
// special access (e.g. a friend declaration).
// NOTE(review): confirm in any_value.h before renaming or moving it.
struct TestAnyValue {
  // Perfect-forwards `value` to AnyValue's constructor.
  template <typename T>
  // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
  explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
  // Returns the stored AnyValue by moving it out of value_. A second call
  // would therefore see a moved-from value_.
  AnyValue operator()() {
    return std::move(value_);
  }
  AnyValue value_;
};
// Wraps `value` in an AnyValue via TestAnyValue. The tests below show that
// the stored type is the decayed argument type (a string literal is stored
// as const char*, a const std::string& as std::string).
template <typename T>
AnyValue make_value(T&& value) {
  return TestAnyValue(std::forward<T>(value))();
}
} // namespace nn
} // namespace torch
362
363 struct AnyValueTest : torch::test::SeedingFixture {};
364
TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  auto stored = make_value<int>(5);
  ASSERT_NE(stored.try_get<int>(), nullptr);
  // int and const int share the same typeid(), yet UBSAN reports casting
  // Holder<int> to Holder<const int> as undefined behavior
  // (https://github.com/pytorch/pytorch/issues/26964), so the
  // const-qualified lookup stays disabled:
  // ASSERT_NE(stored.try_get<const int>(), nullptr);
  ASSERT_EQ(stored.get<int>(), 5);
}
// Disabled: this test cannot work because make_value appears to decay
// `const int` into `int`.
377 // TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
378 // auto value = make_value<const int>(5);
379 // ASSERT_NE(value.try_get<const int>(), nullptr);
380 // // ASSERT_NE(value.try_get<int>(), nullptr);
381 // ASSERT_EQ(value.get<const int>(), 5);
382 //}
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  // A string literal is stored as a const char* pointer.
  auto stored = make_value("hello");
  ASSERT_NE(stored.try_get<const char*>(), nullptr);
  const char* text = stored.get<const char*>();
  ASSERT_EQ(text, std::string("hello"));
}
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  // An rvalue std::string is stored as std::string.
  std::string source("hello");
  auto stored = make_value(std::move(source));
  ASSERT_NE(stored.try_get<std::string>(), nullptr);
  const std::string retrieved = stored.get<std::string>();
  ASSERT_EQ(retrieved, "hello");
}
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  std::string text("hello");
  std::string* pointer = &text;
  // The pointer itself is stored; dereferencing it reaches `text`.
  auto stored = make_value(pointer);
  ASSERT_NE(stored.try_get<std::string*>(), nullptr);
  std::string* retrieved = stored.get<std::string*>();
  ASSERT_EQ(*retrieved, "hello");
}
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  std::string text("hello");
  const std::string& reference = text;
  // Passing a const reference stores the referenced type, std::string.
  auto stored = make_value(reference);
  ASSERT_NE(stored.try_get<std::string>(), nullptr);
  ASSERT_EQ(stored.get<std::string>(), "hello");
}
407
TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  auto stored = make_value(5);
  ASSERT_NE(stored.try_get<int>(), nullptr);
  // Any type other than the exact stored type yields nullptr, including
  // other arithmetic types.
  ASSERT_EQ(stored.try_get<float>(), nullptr);
  ASSERT_EQ(stored.try_get<long>(), nullptr);
  ASSERT_EQ(stored.try_get<std::string>(), nullptr);
}
415
TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  auto stored = make_value(5);
  ASSERT_NE(stored.try_get<int>(), nullptr);
  // Unlike try_get(), get() throws, naming both the requested type and the
  // actual type.
  ASSERT_THROWS_WITH(
      stored.get<float>(),
      "Attempted to cast AnyValue to float, "
      "but its actual type is int");
  ASSERT_THROWS_WITH(
      stored.get<long>(),
      "Attempted to cast AnyValue to long, "
      "but its actual type is int");
}
428
TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  auto source = make_value(5);
  // Building a new AnyValue from a moved one keeps the held int.
  auto moved = make_value(std::move(source));
  ASSERT_NE(moved.try_get<int>(), nullptr);
  ASSERT_EQ(moved.get<int>(), 5);
}
435
TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  auto source = make_value(5);
  auto target = make_value(10);
  // Move-assignment replaces the target's 10 with the source's 5.
  target = std::move(source);
  ASSERT_NE(target.try_get<int>(), nullptr);
  ASSERT_EQ(target.get<int>(), 5);
}
443
TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  auto stored = make_value(5);
  const auto expected_hash = typeid(int).hash_code();
  ASSERT_EQ(stored.type_info().hash_code(), expected_hash);
}
448
TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  // String literals are stored, and reported, as const char*.
  auto stored = make_value("hello");
  const auto expected_hash = typeid(const char*).hash_code();
  ASSERT_EQ(stored.type_info().hash_code(), expected_hash);
}
453
TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  auto stored = make_value(std::string("hello"));
  const auto expected_hash = typeid(std::string).hash_code();
  ASSERT_EQ(stored.type_info().hash_code(), expected_hash);
}
458