xref: /aosp_15_r20/external/pytorch/test/cpp/api/any.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/torch.h>
4 
5 #include <test/cpp/api/support.h>
6 
7 #include <algorithm>
8 #include <string>
9 
10 using namespace torch::nn;
11 
12 struct AnyModuleTest : torch::test::SeedingFixture {};
13 
// A forward() with no parameters and an `int` result is reachable through
// AnyModule::forward<int>().
TEST_F(AnyModuleTest, SimpleReturnType) {
  struct M : torch::nn::Module {
    int forward() {
      constexpr int kResult = 123;
      return kResult;
    }
  };
  AnyModule any(M{});
  const int result = any.forward<int>();
  ASSERT_EQ(result, 123);
}
23 
// A single `int` argument is passed through unchanged and the `int` result is
// returned to the caller.
TEST_F(AnyModuleTest, SimpleReturnTypeAndSingleArgument) {
  struct Echo : torch::nn::Module {
    int forward(int input) {
      return input;
    }
  };
  AnyModule any(Echo{});
  const int echoed = any.forward<int>(5);
  ASSERT_EQ(echoed, 5);
}
33 
// A string-literal argument round-trips through a forward() that takes and
// returns `const char*`; the result is compared by content, not by pointer.
TEST_F(AnyModuleTest, StringLiteralReturnTypeAndArgument) {
  struct Echo : torch::nn::Module {
    const char* forward(const char* text) {
      return text;
    }
  };
  AnyModule any(Echo{});
  const char* echoed = any.forward<const char*>("hello");
  ASSERT_EQ(std::string("hello"), echoed);
}
43 
// forward() may declare a `const` by-value parameter. An `int` lvalue and a
// `double` literal are accepted, and the truncated sum comes back as a string.
TEST_F(AnyModuleTest, StringReturnTypeWithConstArgument) {
  struct M : torch::nn::Module {
    std::string forward(int x, const double f) {
      const double total = x + f;
      return std::to_string(static_cast<int>(total));
    }
  };
  AnyModule any(M{});
  int x = 4;
  const std::string text = any.forward<std::string>(x, 3.14);
  // static_cast<int>(4 + 3.14) == 7
  ASSERT_EQ(text, std::string("7"));
}
54 
// AnyModule forwards arguments to a forward() whose parameters are taken by
// value, by const reference and by rvalue reference, and returns its Tensor.
TEST_F(
    AnyModuleTest,
    TensorReturnTypeAndStringArgumentsWithFunkyQualifications) {
  struct M : torch::nn::Module {
    torch::Tensor forward(
        std::string a,
        const std::string& b,
        std::string&& c) {
      const auto s = a + b + c;
      return torch::ones({static_cast<int64_t>(s.size())});
    }
  };
  AnyModule any(M{});
  // "a" + "ab" + "abc" has 6 characters, so forward() returns ones(6) whose
  // sum is 6. ASSERT_EQ (rather than ASSERT_TRUE(x == 6)) reports the actual
  // sum on failure.
  ASSERT_EQ(
      any.forward(std::string("a"), std::string("ab"), std::string("abc"))
          .sum()
          .item<int32_t>(),
      6);
}
73 
// Passing a `double` where forward() expects `float` is rejected, and the
// error message names both the expected and the received type.
TEST_F(AnyModuleTest, WrongArgumentType) {
  struct M : torch::nn::Module {
    int forward(float x) {
      return static_cast<int>(x);
    }
  };
  AnyModule any(M{});
  ASSERT_THROWS_WITH(
      any.forward(5.0),
      "Expected argument #0 to be of type float, but received value of type double");
}
87 
88 struct M_test_wrong_number_of_arguments : torch::nn::Module {
forwardM_test_wrong_number_of_arguments89   int forward(int a, int b) {
90     return a + b;
91   }
92 };
93 
// Calling forward() with too few or too many arguments throws. Too few
// arguments also produce a hint about the FORWARD_HAS_DEFAULT_ARGS macro.
TEST_F(AnyModuleTest, WrongNumberOfArguments) {
  AnyModule any(M_test_wrong_number_of_arguments{});
  // On MSVC the reported type name carries the "struct " keyword.
#if defined(_MSC_VER)
  const std::string module_name = "struct M_test_wrong_number_of_arguments";
#else
  const std::string module_name = "M_test_wrong_number_of_arguments";
#endif

  // Leading sentence shared by every arity error for this module.
  const auto expects_two_but_got = [&module_name](int received) {
    return module_name +
        "'s forward() method expects 2 argument(s), but received " +
        std::to_string(received) + ".";
  };
  // Extra sentence appended only when too few arguments were passed.
  const std::string default_args_hint = " If " + module_name +
      "'s forward() method has default arguments, "
      "please make sure the forward() method is declared with a "
      "corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";

  ASSERT_THROWS_WITH(any.forward(), expects_two_but_got(0) + default_args_hint);
  ASSERT_THROWS_WITH(any.forward(5), expects_two_but_got(1) + default_args_hint);
  ASSERT_THROWS_WITH(any.forward(1, 2, 3), expects_two_but_got(3));
}
122 
123 struct M_default_arg_with_macro : torch::nn::Module {
forwardM_default_arg_with_macro124   double forward(int a, int b = 2, double c = 3.0) {
125     return a + b + c;
126   }
127 
128  protected:
129   FORWARD_HAS_DEFAULT_ARGS(
130       {1, torch::nn::AnyValue(2)},
131       {2, torch::nn::AnyValue(3.0)})
132 };
133 
134 struct M_default_arg_without_macro : torch::nn::Module {
forwardM_default_arg_without_macro135   double forward(int a, int b = 2, double c = 3.0) {
136     return a + b + c;
137   }
138 };
139 
TEST_F(AnyModuleTest,PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod)140 TEST_F(
141     AnyModuleTest,
142     PassingArgumentsToModuleWithDefaultArgumentsInForwardMethod) {
143   {
144     AnyModule any(M_default_arg_with_macro{});
145 
146     ASSERT_EQ(any.forward<double>(1), 6.0);
147     ASSERT_EQ(any.forward<double>(1, 3), 7.0);
148     ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
149 
150     ASSERT_THROWS_WITH(
151         any.forward(),
152         "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 0.");
153     ASSERT_THROWS_WITH(
154         any.forward(1, 2, 3.0, 4),
155         "M_default_arg_with_macro's forward() method expects at least 1 argument(s) and at most 3 argument(s), but received 4.");
156   }
157   {
158     AnyModule any(M_default_arg_without_macro{});
159 
160     ASSERT_EQ(any.forward<double>(1, 3, 5.0), 9.0);
161 
162 #if defined(_MSC_VER)
163     std::string module_name = "struct M_default_arg_without_macro";
164 #else
165     std::string module_name = "M_default_arg_without_macro";
166 #endif
167 
168     ASSERT_THROWS_WITH(
169         any.forward(),
170         module_name +
171             "'s forward() method expects 3 argument(s), but received 0. "
172             "If " +
173             module_name +
174             "'s forward() method has default arguments, "
175             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
176     ASSERT_THROWS_WITH(
177         any.forward<double>(1),
178         module_name +
179             "'s forward() method expects 3 argument(s), but received 1. "
180             "If " +
181             module_name +
182             "'s forward() method has default arguments, "
183             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
184     ASSERT_THROWS_WITH(
185         any.forward<double>(1, 3),
186         module_name +
187             "'s forward() method expects 3 argument(s), but received 2. "
188             "If " +
189             module_name +
190             "'s forward() method has default arguments, "
191             "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.");
192     ASSERT_THROWS_WITH(
193         any.forward(1, 2, 3.0, 4),
194         module_name +
195             "'s forward() method expects 3 argument(s), but received 4.");
196   }
197 }
198 
199 struct M : torch::nn::Module {
MM200   explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
201   int value;
forwardM202   int forward(float x) {
203     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
204     return x;
205   }
206 };
207 
// get<T>() with the stored module's exact type gives access to that module.
TEST_F(AnyModuleTest, GetWithCorrectTypeSucceeds) {
  AnyModule any(M{5});
  const auto& module = any.get<M>();
  ASSERT_EQ(module.value, 5);
}
212 
// get<T>() with a module type other than the stored one throws.
TEST_F(AnyModuleTest, GetWithIncorrectTypeThrows) {
  // Unrelated module type used only as the wrong cast target.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  AnyModule any(M{5});
  ASSERT_THROWS_WITH(any.get<N>(), "Attempted to cast module");
}
222 
// ptr() without a template argument yields a non-null base-class pointer to
// the stored module, which keeps the name passed to Module's constructor.
TEST_F(AnyModuleTest, PtrWithBaseClassSucceeds) {
  AnyModule any(M{5});
  const auto base = any.ptr();
  ASSERT_NE(base, nullptr);
  ASSERT_EQ(base->name(), "M");
}
229 
// ptr<T>() with the stored module's exact type yields a usable typed pointer.
// (The "Succceeds" spelling is kept so the test ID stays stable.)
TEST_F(AnyModuleTest, PtrWithGoodDowncastSuccceeds) {
  AnyModule any(M{5});
  const auto derived = any.ptr<M>();
  ASSERT_NE(derived, nullptr);
  ASSERT_EQ(derived->value, 5);
}
236 
// ptr<T>() with a module type other than the stored one throws.
TEST_F(AnyModuleTest, PtrWithBadDowncastThrows) {
  // Unrelated module type used only as the wrong cast target.
  struct N : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  AnyModule any(M{5});
  ASSERT_THROWS_WITH(any.ptr<N>(), "Attempted to cast module");
}
246 
// A default-constructed AnyModule is empty until a module is assigned to it.
TEST_F(AnyModuleTest, DefaultStateIsEmpty) {
  // Local module (shadows the file-level M) without an explicit name.
  struct M : torch::nn::Module {
    explicit M(int initial) : value(initial) {}
    int value;
    int forward(float x) {
      return static_cast<int>(x);
    }
  };

  AnyModule any;
  ASSERT_TRUE(any.is_empty());

  // Assigning a shared_ptr to a module fills the AnyModule.
  any = std::make_shared<M>(5);
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.get<M>().value, 5);
}
262 
// Every accessor of an empty AnyModule throws an error naming the method.
TEST_F(AnyModuleTest, AllMethodsThrowForEmptyAnyModule) {
  // Only used as a template argument; never instantiated.
  struct M : torch::nn::Module {
    int forward(int x) {
      return x;
    }
  };

  AnyModule any;
  ASSERT_TRUE(any.is_empty());

  ASSERT_THROWS_WITH(any.get<M>(), "Cannot call get() on an empty AnyModule");
  ASSERT_THROWS_WITH(any.ptr<M>(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(any.ptr(), "Cannot call ptr() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      any.type_info(),
      "Cannot call type_info() on an empty AnyModule");
  ASSERT_THROWS_WITH(
      any.forward<int>(5),
      "Cannot call forward() on an empty AnyModule");
}
279 
// One AnyModule can be reassigned to modules of unrelated types, each with its
// own forward() signature.
TEST_F(AnyModuleTest, CanMoveAssignDifferentModules) {
  struct M : torch::nn::Module {
    std::string forward(int x) {
      return std::to_string(x);
    }
  };
  struct N : torch::nn::Module {
    int forward(float x) {
      const float shifted = 3 + x;
      return static_cast<int>(shifted);
    }
  };

  AnyModule any;
  ASSERT_TRUE(any.is_empty());

  any = std::make_shared<M>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<std::string>(5), "5");

  any = std::make_shared<N>();
  ASSERT_FALSE(any.is_empty());
  ASSERT_EQ(any.forward<int>(5.0f), 8); // 3 + 5.0f
}
301 
// AnyModule can be built from a ModuleHolder (a custom one and the built-in
// Linear). The stored module is reachable both as the Impl type and as the
// holder type.
TEST_F(AnyModuleTest, ConstructsFromModuleHolder) {
  struct MImpl : torch::nn::Module {
    explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
    int value;
    int forward(float x) {
      return static_cast<int>(x);
    }
  };

  struct M : torch::nn::ModuleHolder<MImpl> {
    using torch::nn::ModuleHolder<MImpl>::ModuleHolder;
    using torch::nn::ModuleHolder<MImpl>::get;
  };

  AnyModule any(M{5});
  ASSERT_EQ(any.get<MImpl>().value, 5);
  ASSERT_EQ(any.get<M>()->value, 5);

  // Round-trip a built-in holder. These results were previously computed but
  // never checked, so this half of the test could not fail.
  AnyModule module(Linear(3, 4));
  std::shared_ptr<Module> ptr = module.ptr();
  Linear linear(module.get<Linear>());
  ASSERT_NE(ptr, nullptr);
  ASSERT_EQ(linear->options.in_features(), 3);
  ASSERT_EQ(linear->options.out_features(), 4);
  // The recovered holder refers to the very module stored in the AnyModule.
  ASSERT_EQ(ptr.get(), linear.ptr().get());
}
325 
// A forward(torch::Tensor) module accepts both autograd Variables and plain
// at::Tensors through AnyModule.
TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return input;
    }
  };

  // When you have an autograd::Variable, it should be converted to a
  // torch::Tensor before being passed to the function (to avoid a type
  // mismatch).
  // NOTE(review): in current libtorch Variable may simply alias Tensor, which
  // would make this conversion trivial — confirm.
  AnyModule any(M{});
  // ASSERT_EQ (rather than ASSERT_TRUE(x == 5)) reports the actual sum on
  // failure, matching the assertion below.
  ASSERT_EQ(
      any.forward(torch::autograd::Variable(torch::ones(5)))
          .sum()
          .item<float>(),
      5);
  // at::Tensors that are not variables work too.
  ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
344 
345 namespace torch {
346 namespace nn {
347 struct TestAnyValue {
348   template <typename T>
349   // NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
TestAnyValuetorch::nn::TestAnyValue350   explicit TestAnyValue(T&& value) : value_(std::forward<T>(value)) {}
operator ()torch::nn::TestAnyValue351   AnyValue operator()() {
352     return std::move(value_);
353   }
354   AnyValue value_;
355 };
356 template <typename T>
make_value(T && value)357 AnyValue make_value(T&& value) {
358   return TestAnyValue(std::forward<T>(value))();
359 }
360 } // namespace nn
361 } // namespace torch
362 
363 struct AnyValueTest : torch::test::SeedingFixture {};
364 
// An int stored in an AnyValue is found by try_get<int>() and get<int>().
TEST_F(AnyValueTest, CorrectlyAccessesIntWhenCorrectType) {
  AnyValue value = make_value<int>(5);
  const auto* stored = value.try_get<int>();
  ASSERT_NE(stored, nullptr);
  // const and non-const types have the same typeid(),
  // but casting Holder<int> to Holder<const int> is undefined
  // behavior according to UBSAN:
  // https://github.com/pytorch/pytorch/issues/26964
  // ASSERT_NE(value.try_get<const int>(), nullptr);
  ASSERT_EQ(value.get<int>(), 5);
}
// Disabled: make_value appears to decay `const int` to `int`, so the value is
// never stored as a `const int` and this test cannot pass as written.
// TEST_F(AnyValueTest, CorrectlyAccessesConstIntWhenCorrectType) {
//  auto value = make_value<const int>(5);
//  ASSERT_NE(value.try_get<const int>(), nullptr);
//  // ASSERT_NE(value.try_get<int>(), nullptr);
//  ASSERT_EQ(value.get<const int>(), 5);
//}
// A string literal is stored as `const char*` and reads back with equal content.
TEST_F(AnyValueTest, CorrectlyAccessesStringLiteralWhenCorrectType) {
  AnyValue value = make_value("hello");
  ASSERT_NE(value.try_get<const char*>(), nullptr);
  const std::string text = value.get<const char*>();
  ASSERT_EQ(text, "hello");
}
// A std::string stored in an AnyValue reads back unchanged.
TEST_F(AnyValueTest, CorrectlyAccessesStringWhenCorrectType) {
  AnyValue value = make_value(std::string("hello"));
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  const std::string text = value.get<std::string>();
  ASSERT_EQ(text, "hello");
}
// A raw pointer is stored as-is, and dereferencing the retrieved pointer
// reaches the original object.
TEST_F(AnyValueTest, CorrectlyAccessesPointersWhenCorrectType) {
  std::string s("hello");
  std::string* p = &s;
  AnyValue value = make_value(p);
  ASSERT_NE(value.try_get<std::string*>(), nullptr);
  const std::string* const target = value.get<std::string*>();
  ASSERT_EQ(*target, "hello");
}
// Passing a `const std::string&` stores a std::string (retrieved via
// try_get/get<std::string>), not a reference type.
TEST_F(AnyValueTest, CorrectlyAccessesReferencesWhenCorrectType) {
  std::string s("hello");
  const std::string& t = s;
  AnyValue value = make_value(t);
  ASSERT_NE(value.try_get<std::string>(), nullptr);
  const std::string text = value.get<std::string>();
  ASSERT_EQ(text, "hello");
}
407 
// try_get<T>() succeeds only for the exact stored type; other numeric and
// string types yield nullptr instead of a conversion.
TEST_F(AnyValueTest, TryGetReturnsNullptrForTheWrongType) {
  AnyValue value = make_value(5);
  ASSERT_NE(value.try_get<int>(), nullptr);

  ASSERT_EQ(value.try_get<float>(), nullptr);
  ASSERT_EQ(value.try_get<long>(), nullptr);
  ASSERT_EQ(value.try_get<std::string>(), nullptr);
}
415 
// get<T>() with a type other than the stored one throws, and the message
// names both the requested and the actual type.
TEST_F(AnyValueTest, GetThrowsForTheWrongType) {
  AnyValue value = make_value(5);
  ASSERT_NE(value.try_get<int>(), nullptr);

  ASSERT_THROWS_WITH(
      value.get<float>(),
      "Attempted to cast AnyValue to float, but its actual type is int");
  ASSERT_THROWS_WITH(
      value.get<long>(),
      "Attempted to cast AnyValue to long, but its actual type is int");
}
428 
// An AnyValue can be built from an rvalue AnyValue, and the result keeps
// the original stored int.
TEST_F(AnyValueTest, MoveConstructionIsAllowed) {
  AnyValue original = make_value(5);
  AnyValue moved = make_value(std::move(original));
  ASSERT_NE(moved.try_get<int>(), nullptr);
  ASSERT_EQ(moved.get<int>(), 5);
}
435 
// Move-assigning replaces the target's stored value with the source's.
TEST_F(AnyValueTest, MoveAssignmentIsAllowed) {
  AnyValue source = make_value(5);
  AnyValue target = make_value(10);
  target = std::move(source);
  ASSERT_NE(target.try_get<int>(), nullptr);
  ASSERT_EQ(target.get<int>(), 5);
}
443 
// type_info() of a stored int matches typeid(int), compared via hash_code().
TEST_F(AnyValueTest, TypeInfoIsCorrectForInt) {
  AnyValue value = make_value(5);
  const auto& stored_type = value.type_info();
  ASSERT_EQ(stored_type.hash_code(), typeid(int).hash_code());
}
448 
// A string literal is stored as `const char*`, as reported by type_info().
TEST_F(AnyValueTest, TypeInfoIsCorrectForStringLiteral) {
  AnyValue value = make_value("hello");
  const auto& stored_type = value.type_info();
  ASSERT_EQ(stored_type.hash_code(), typeid(const char*).hash_code());
}
453 
// type_info() of a stored std::string matches typeid(std::string).
TEST_F(AnyValueTest, TypeInfoIsCorrectForString) {
  AnyValue value = make_value(std::string("hello"));
  const auto& stored_type = value.type_info();
  ASSERT_EQ(stored_type.hash_code(), typeid(std::string).hash_code());
}
458