xref: /aosp_15_r20/external/pytorch/test/cpp/api/expanding-array.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <cstddef>
9*da0073e9SAndroid Build Coastguard Worker #include <initializer_list>
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker struct ExpandingArrayTest : torch::test::SeedingFixture {};
13*da0073e9SAndroid Build Coastguard Worker 
// An ExpandingArray<5> built from a braced list of five values must report
// size 5 and hold those values (1..5) in the order given.
TEST_F(ExpandingArrayTest, CanConstructFromInitializerList) {
  torch::ExpandingArray<5> array({1, 2, 3, 4, 5});
  ASSERT_EQ(array.size(), 5);

  // Dereferencing exposes the underlying storage; element k must equal k + 1.
  auto& values = *array;
  for (const auto index : c10::irange(array.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
21*da0073e9SAndroid Build Coastguard Worker 
// An ExpandingArray<5> built from a std::vector<int64_t> of five values must
// report size 5 and hold those values (1..5) in the order given.
TEST_F(ExpandingArrayTest, CanConstructFromVector) {
  std::vector<int64_t> source{1, 2, 3, 4, 5};
  torch::ExpandingArray<5> array(source);
  ASSERT_EQ(array.size(), 5);

  // Dereferencing exposes the underlying storage; element k must equal k + 1.
  auto& values = *array;
  for (const auto index : c10::irange(array.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
29*da0073e9SAndroid Build Coastguard Worker 
// An ExpandingArray<5> built from a std::array<int64_t, 5> must report size 5
// and hold the same values (1..5) in the same order.
TEST_F(ExpandingArrayTest, CanConstructFromArray) {
  const std::array<int64_t, 5> source({1, 2, 3, 4, 5});
  torch::ExpandingArray<5> array(source);
  ASSERT_EQ(array.size(), 5);

  // Dereferencing exposes the underlying storage; element k must equal k + 1.
  auto& values = *array;
  for (const auto index : c10::irange(array.size())) {
    ASSERT_EQ(values[index], index + 1);
  }
}
37*da0073e9SAndroid Build Coastguard Worker 
// An ExpandingArray<5> built from the single value 5 must report size 5 and
// hold that value in every slot.
TEST_F(ExpandingArrayTest, CanConstructFromSingleValue) {
  torch::ExpandingArray<5> array(5);
  ASSERT_EQ(array.size(), 5);

  // Walk the underlying storage directly; every element must be the fill value.
  for (const auto& value : *array) {
    ASSERT_EQ(value, 5);
  }
}
45*da0073e9SAndroid Build Coastguard Worker 
// Passing seven values in a braced list to an ExpandingArray<5> must throw,
// and the error text must name both the expected and the actual count.
// clang-format off
TEST_F(ExpandingArrayTest, ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInInitializerList) {
  // clang-format on
  ASSERT_THROWS_WITH(torch::ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}),
                     "Expected 5 values, but instead got 7");
}
53*da0073e9SAndroid Build Coastguard Worker 
// Passing a seven-element std::vector<int64_t> to an ExpandingArray<5> must
// throw, and the error text must name both the expected and the actual count.
// clang-format off
TEST_F(ExpandingArrayTest, ThrowsWhenConstructedWithIncorrectNumberOfArgumentsInVector) {
  // clang-format on
  ASSERT_THROWS_WITH(
      torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
      "Expected 5 values, but instead got 7");
}
61