// Source: /aosp_15_r20/external/pytorch/test/cpp/api/nested_int.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/core/NestedIntSymNodeImpl.h>
4 #include <c10/core/SymInt.h>
5 #include <c10/core/SymNodeImpl.h>
6 #include <torch/torch.h>
7 
8 #include <test/cpp/api/support.h>
9 
TEST(NestedIntTest,Comparisons)10 TEST(NestedIntTest, Comparisons) {
11   auto a = c10::SymInt(
12       c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
13   auto b = c10::SymInt(
14       c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
15   auto c = c10::SymInt(
16       c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(2, 1)));
17   auto d = c10::SymInt(3);
18 
19   ASSERT_TRUE(a == a);
20   ASSERT_TRUE(a == b);
21   ASSERT_FALSE(a != a);
22   ASSERT_FALSE(a != b);
23   ASSERT_FALSE(a == c);
24   ASSERT_TRUE(a != c);
25 
26   ASSERT_FALSE(a == d);
27   ASSERT_TRUE(a != d);
28   ASSERT_FALSE(d == a);
29   ASSERT_TRUE(d != a);
30 
31   // ge
32   ASSERT_TRUE(a >= a);
33   ASSERT_TRUE(a >= b);
34   ASSERT_TRUE(b >= a);
35   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
36   EXPECT_THROW((void)(a >= c), c10::Error);
37   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
38   EXPECT_THROW((void)(c >= a), c10::Error);
39   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
40   EXPECT_THROW((void)(c >= 3), c10::Error);
41   ASSERT_TRUE(c >= 2);
42   ASSERT_TRUE(c >= 1);
43   ASSERT_FALSE(1 >= c);
44 
45   // lt
46   ASSERT_FALSE(a < a);
47   ASSERT_FALSE(a < b);
48   ASSERT_FALSE(b < a);
49   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
50   EXPECT_THROW((void)(a < c), c10::Error);
51   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
52   EXPECT_THROW((void)(c < a), c10::Error);
53   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
54   EXPECT_THROW((void)(3 < a), c10::Error);
55   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
56   EXPECT_THROW((void)(2 < a), c10::Error);
57   ASSERT_TRUE(1 < a);
58 
59   // le
60   ASSERT_TRUE(a <= a);
61   ASSERT_TRUE(b <= a);
62   ASSERT_TRUE(a <= b);
63   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
64   EXPECT_THROW((void)(a <= c), c10::Error);
65   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
66   EXPECT_THROW((void)(c <= a), c10::Error);
67   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
68   EXPECT_THROW((void)(3 <= c), c10::Error);
69   ASSERT_TRUE(2 <= c);
70   ASSERT_TRUE(1 <= c);
71   ASSERT_FALSE(c <= 1);
72 
73   // gt
74   ASSERT_FALSE(a > a);
75   ASSERT_FALSE(b > a);
76   ASSERT_FALSE(a > b);
77   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
78   EXPECT_THROW((void)(a > c), c10::Error);
79   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
80   EXPECT_THROW((void)(c > a), c10::Error);
81   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
82   EXPECT_THROW((void)(a > 3), c10::Error);
83   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
84   EXPECT_THROW((void)(a > 2), c10::Error);
85   ASSERT_TRUE(a > 1);
86 }
87 
TEST(NestedIntTest, WithFactor) {
  // Two nested ints sharing id 1 but constructed with different second
  // arguments (5 and 10) -- presumably the coefficient; the `a * 2 == b`
  // check below relies on that interpretation.
  auto a = c10::SymInt(
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 5)));
  auto b = c10::SymInt(
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 10)));

  // eq: same id but different coefficients are not equal.
  ASSERT_FALSE(a == b);

  // ge / le: with a shared id the ordering is determinate (no c10::Error is
  // expected) and follows the coefficients, 5 vs 10.
  ASSERT_FALSE(a >= b);
  ASSERT_TRUE(b >= a);
  ASSERT_TRUE(a <= b);
  ASSERT_FALSE(b <= a);

  // ne
  ASSERT_TRUE(a != b);

  // mul: multiplying by an integer scales the coefficient, is commutative,
  // and the product can be compared against another nested int.
  ASSERT_TRUE(a * 2 == b);
  ASSERT_TRUE(a * 3 >= b);
  ASSERT_TRUE(a * 2 == 2 * a);
}
106