xref: /aosp_15_r20/external/pytorch/test/cpp/api/expanding-array.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <vector>
11 
12 struct ExpandingArrayTest : torch::test::SeedingFixture {};
13 
// Constructing from a braced list of exactly five values stores them in order.
TEST_F(ExpandingArrayTest, CanConstructFromInitializerList) {
  torch::ExpandingArray<5> e({1, 2, 3, 4, 5});
  // Use an unsigned literal for the size, and cast the index to int64_t for
  // the element check, so ASSERT_EQ never compares signed with unsigned
  // operands (-Wsign-compare).
  ASSERT_EQ(e.size(), 5u);
  for (const auto i : c10::irange(e.size())) {
    ASSERT_EQ((*e)[i], static_cast<int64_t>(i) + 1);
  }
}
21 
// Constructing from a std::vector of exactly five values stores them in order.
TEST_F(ExpandingArrayTest, CanConstructFromVector) {
  torch::ExpandingArray<5> e(std::vector<int64_t>{1, 2, 3, 4, 5});
  // Use an unsigned literal for the size, and cast the index to int64_t for
  // the element check, so ASSERT_EQ never compares signed with unsigned
  // operands (-Wsign-compare).
  ASSERT_EQ(e.size(), 5u);
  for (const auto i : c10::irange(e.size())) {
    ASSERT_EQ((*e)[i], static_cast<int64_t>(i) + 1);
  }
}
29 
// Constructing from a std::array of exactly five values stores them in order.
TEST_F(ExpandingArrayTest, CanConstructFromArray) {
  torch::ExpandingArray<5> e(std::array<int64_t, 5>({1, 2, 3, 4, 5}));
  // Use an unsigned literal for the size, and cast the index to int64_t for
  // the element check, so ASSERT_EQ never compares signed with unsigned
  // operands (-Wsign-compare).
  ASSERT_EQ(e.size(), 5u);
  for (const auto i : c10::irange(e.size())) {
    ASSERT_EQ((*e)[i], static_cast<int64_t>(i) + 1);
  }
}
37 
// Constructing from a single value fills every one of the five slots with it.
TEST_F(ExpandingArrayTest, CanConstructFromSingleValue) {
  torch::ExpandingArray<5> e(5);
  // Unsigned literal so ASSERT_EQ does not compare e.size() against a signed
  // int (-Wsign-compare).
  ASSERT_EQ(e.size(), 5u);
  for (const auto i : c10::irange(e.size())) {
    ASSERT_EQ((*e)[i], 5);
  }
}
45 
// Constructing an ExpandingArray<5> from a braced list whose length is not 5
// must throw, with a message naming both the expected and the actual count.
// Both sides of the length check are exercised: too many and too few values.
TEST_F(
    ExpandingArrayTest,
    ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInInitializerList) {
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}),
      "Expected 5 values, but instead got 7");
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>({1, 2, 3}),
      "Expected 5 values, but instead got 3");
}
53 
// Constructing an ExpandingArray<5> from a std::vector whose length is not 5
// must throw, with a message naming both the expected and the actual count.
// Both sides of the length check are exercised: too many and too few values.
TEST_F(
    ExpandingArrayTest,
    ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInVector) {
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
      "Expected 5 values, but instead got 7");
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3})),
      "Expected 5 values, but instead got 3");
}
61