1 #include <gtest/gtest.h>
2
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5
6 #include <test/cpp/api/support.h>
7
8 #include <cstddef>
9 #include <initializer_list>
10 #include <vector>
11
12 struct ExpandingArrayTest : torch::test::SeedingFixture {};
13
// A braced list with exactly five entries must yield a five-element array
// whose i-th slot holds i + 1.
TEST_F(ExpandingArrayTest, CanConstructFromInitializerList) {
  torch::ExpandingArray<5> expanded({1, 2, 3, 4, 5});
  ASSERT_EQ(expanded.size(), 5);

  // Dereference once and index the underlying storage inside the loop.
  const auto& values = *expanded;
  for (const auto index : c10::irange(expanded.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
21
// A std::vector<int64_t> with exactly five entries must yield a five-element
// array whose i-th slot holds i + 1.
TEST_F(ExpandingArrayTest, CanConstructFromVector) {
  torch::ExpandingArray<5> expanded(std::vector<int64_t>{1, 2, 3, 4, 5});
  ASSERT_EQ(expanded.size(), 5);

  // Dereference once and index the underlying storage inside the loop.
  const auto& values = *expanded;
  for (const auto index : c10::irange(expanded.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
29
// A std::array<int64_t, 5> must yield a five-element array whose i-th slot
// holds i + 1.
TEST_F(ExpandingArrayTest, CanConstructFromArray) {
  torch::ExpandingArray<5> expanded(std::array<int64_t, 5>({1, 2, 3, 4, 5}));
  ASSERT_EQ(expanded.size(), 5);

  // Dereference once and index the underlying storage inside the loop.
  const auto& values = *expanded;
  for (const auto index : c10::irange(expanded.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
37
// Constructing from one scalar must produce five slots that all hold that
// scalar.
// NOTE(review): the scalar (5) equals the array length (5), so this test
// cannot tell "filled with the value" apart from "filled with the size".
TEST_F(ExpandingArrayTest, CanConstructFromSingleValue) {
  torch::ExpandingArray<5> expanded(5);
  ASSERT_EQ(expanded.size(), 5);

  // Dereference once and index the underlying storage inside the loop.
  const auto& values = *expanded;
  for (const auto index : c10::irange(expanded.size())) {
    ASSERT_EQ(values[index], 5);
  }
}
45
// Passing seven values in a braced list to a five-element array must throw,
// and the error text must report both the expected and the actual count.
TEST_F(
    ExpandingArrayTest,
    ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInInitializerList) {
  const char* const expected_message = "Expected 5 values, but instead got 7";
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}), expected_message);
}
53
// Passing a seven-element vector to a five-element array must throw, and the
// error text must report both the expected and the actual count.
TEST_F(
    ExpandingArrayTest,
    ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInVector) {
  const char* const expected_message = "Expected 5 values, but instead got 7";
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
      expected_message);
}
61