1 #include <gtest/gtest.h>
2
3 #include <ATen/core/NestedIntSymNodeImpl.h>
4 #include <c10/core/SymInt.h>
5 #include <c10/core/SymNodeImpl.h>
6 #include <torch/torch.h>
7
8 #include <test/cpp/api/support.h>
9
TEST(NestedIntTest, Comparisons) {
  // Builds a SymInt whose node is a NestedIntSymNodeImpl(val, coeff).
  const auto nested_int = [](int64_t val, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };

  const c10::SymInt n1 = nested_int(1, 1);
  const c10::SymInt n1_again = nested_int(1, 1); // same arguments as n1
  const c10::SymInt n2 = nested_int(2, 1); // differs from n1 in `val` only
  const c10::SymInt three(3); // ordinary constant SymInt

  // eq / ne between nested ints: identical (val, coeff) compare equal,
  // a different `val` does not.
  ASSERT_TRUE(n1 == n1);
  ASSERT_TRUE(n1 == n1_again);
  ASSERT_FALSE(n1 != n1);
  ASSERT_FALSE(n1 != n1_again);
  ASSERT_FALSE(n1 == n2);
  ASSERT_TRUE(n1 != n2);

  // eq / ne against a constant, with the constant on either side.
  ASSERT_FALSE(n1 == three);
  ASSERT_TRUE(n1 != three);
  ASSERT_FALSE(three == n1);
  ASSERT_TRUE(three != n1);

  // The ordering checks below encode that a nested int is treated as being
  // at least 2: any relation decided by that bound yields a bool, while
  // tighter relations (against 2 or 3), or between nested ints with a
  // different `val`, throw c10::Error.

  // ge
  ASSERT_TRUE(n1 >= n1);
  ASSERT_TRUE(n1 >= n1_again);
  ASSERT_TRUE(n1_again >= n1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 >= n2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n2 >= n1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n2 >= 3), c10::Error);
  ASSERT_TRUE(n2 >= 2);
  ASSERT_TRUE(n2 >= 1);
  ASSERT_FALSE(1 >= n2);

  // lt
  ASSERT_FALSE(n1 < n1);
  ASSERT_FALSE(n1 < n1_again);
  ASSERT_FALSE(n1_again < n1);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 < n2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n2 < n1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(3 < n1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(2 < n1), c10::Error);
  ASSERT_TRUE(1 < n1);

  // le
  ASSERT_TRUE(n1 <= n1);
  ASSERT_TRUE(n1_again <= n1);
  ASSERT_TRUE(n1 <= n1_again);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 <= n2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n2 <= n1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(3 <= n2), c10::Error);
  ASSERT_TRUE(2 <= n2);
  ASSERT_TRUE(1 <= n2);
  ASSERT_FALSE(n2 <= 1);

  // gt
  ASSERT_FALSE(n1 > n1);
  ASSERT_FALSE(n1_again > n1);
  ASSERT_FALSE(n1 > n1_again);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 > n2), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n2 > n1), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 > 3), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(static_cast<void>(n1 > 2), c10::Error);
  ASSERT_TRUE(n1 > 1);
}
87
TEST(NestedIntTest, WithFactor) {
  // Builds a SymInt whose node is a NestedIntSymNodeImpl(val, coeff).
  const auto nested_int = [](int64_t val, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(val, coeff)));
  };

  // Same `val`, different `coeff` (5 vs. 10).
  const c10::SymInt coeff5 = nested_int(1, 5);
  const c10::SymInt coeff10 = nested_int(1, 10);

  // eq: a different coefficient means not equal.
  ASSERT_FALSE(coeff5 == coeff10);
  // ge / le: with a shared `val`, ordering follows the coefficients.
  ASSERT_FALSE(coeff5 >= coeff10);
  ASSERT_TRUE(coeff10 >= coeff5);
  ASSERT_TRUE(coeff5 <= coeff10);
  ASSERT_FALSE(coeff10 <= coeff5);
  // ne
  ASSERT_TRUE(coeff5 != coeff10);
  // mul: multiplying by an integer (on either side) is expected to scale the
  // coefficient, so 5 * 2 matches the coeff-10 nested int.
  ASSERT_TRUE(coeff5 * 2 == coeff10);
  ASSERT_TRUE(coeff5 * 3 >= coeff10);
  ASSERT_TRUE(coeff5 * 2 == 2 * coeff5);
}
106