// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_union.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <algorithm>

#include <ATen/core/jit_type.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>

7 namespace torch {
8 namespace jit {
9 
10 class UnionTypeTest : public ::testing::Test {
11  public:
12   // None
13   const TypePtr none = NoneType::get();
14 
15   // List[str]
16   const TypePtr l1 = ListType::ofStrings();
17 
18   // Optional[int]
19   const TypePtr opt1 = OptionalType::create(IntType::get());
20 
21   // Optional[float]
22   const TypePtr opt2 = OptionalType::create(FloatType::get());
23 
24   // Optional[List[str]]
25   const TypePtr opt3 = OptionalType::create(ListType::ofStrings());
26 
27   // Tuple[Optional[int], int]
28   const TypePtr tup1 =
29       TupleType::create({OptionalType::create(IntType::get()), IntType::get()});
30 
31   // Tuple[int, int]
32   const TypePtr tup2 = TupleType::create({IntType::get(), IntType::get()});
33 
hasType(UnionTypePtr u,TypePtr t)34   bool hasType(UnionTypePtr u, TypePtr t) {
35     auto res = std::find(u->getTypes().begin(), u->getTypes().end(), t);
36     return res != u->getTypes().end();
37   }
38 };
39 
TEST_F(UnionTypeTest,UnionOperatorEquals)40 TEST_F(UnionTypeTest, UnionOperatorEquals) {
41   const UnionTypePtr u1 = UnionType::create({l1, tup2, StringType::get()});
42 
43   // Same thing, but using different TypePtrs
44   const TypePtr l1_ = ListType::ofStrings();
45   const TypePtr tup2_ = TupleType::create({IntType::get(), IntType::get()});
46   const UnionTypePtr u2 = UnionType::create({l1_, tup2_, StringType::get()});
47 
48   ASSERT_TRUE(*u1 == *u2);
49 }
50 
TEST_F(UnionTypeTest,UnionCreate_OptionalT1AndOptionalT2)51 TEST_F(UnionTypeTest, UnionCreate_OptionalT1AndOptionalT2) {
52   // Goal: Union[int, float, None]
53   const UnionTypePtr u = UnionType::create({opt1, opt2});
54 
55   ASSERT_EQ(u->getTypes().size(), 3);
56   ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
57   ASSERT_TRUE(UnionTypeTest::hasType(u, FloatType::get()));
58   ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
59 }
60 
TEST_F(UnionTypeTest,UnionCreate_OptionalTAndT)61 TEST_F(UnionTypeTest, UnionCreate_OptionalTAndT) {
62   // Goal: Union[int, None]
63   const UnionTypePtr u = UnionType::create({opt1, IntType::get()});
64 
65   ASSERT_EQ(u->getTypes().size(), 2);
66   ASSERT_TRUE(UnionTypeTest::hasType(u, IntType::get()));
67   ASSERT_TRUE(UnionTypeTest::hasType(u, NoneType::get()));
68 }
69 
TEST_F(UnionTypeTest,UnionCreate_TupleWithSubtypingRelationship)70 TEST_F(UnionTypeTest, UnionCreate_TupleWithSubtypingRelationship) {
71   // Goal: Union[Tuple[Optional[int], int], str]
72   const UnionTypePtr u = UnionType::create({StringType::get(), tup1, tup2});
73 
74   ASSERT_EQ(u->getTypes().size(), 2);
75   ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
76   ASSERT_TRUE(UnionTypeTest::hasType(u, tup1));
77 }
78 
TEST_F(UnionTypeTest,UnionCreate_ContainerTAndT)79 TEST_F(UnionTypeTest, UnionCreate_ContainerTAndT) {
80   // Goal: Union[List[str], str]
81   const UnionTypePtr u = UnionType::create({l1, StringType::get()});
82 
83   ASSERT_EQ(u->getTypes().size(), 2);
84   ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
85   ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
86 }
87 
TEST_F(UnionTypeTest,UnionCreate_OptionalContainerTAndContainerTAndT)88 TEST_F(UnionTypeTest, UnionCreate_OptionalContainerTAndContainerTAndT) {
89   // Goal: Union[List[str], None, str]
90   const UnionTypePtr u = UnionType::create({l1, opt3, StringType::get()});
91 
92   ASSERT_EQ(u->getTypes().size(), 3);
93   ASSERT_TRUE(UnionTypeTest::hasType(u, StringType::get()));
94   ASSERT_TRUE(UnionTypeTest::hasType(u, ListType::ofStrings()));
95 }
96 
TEST_F(UnionTypeTest,Subtyping_NumberType)97 TEST_F(UnionTypeTest, Subtyping_NumberType) {
98   // Union[int, float, Complex]
99   const UnionTypePtr union1 =
100       UnionType::create({IntType::get(), FloatType::get(), ComplexType::get()});
101 
102   // Union[int, float, Complex, None]
103   const UnionTypePtr union2 = UnionType::create(
104       {IntType::get(), FloatType::get(), ComplexType::get(), NoneType::get()});
105 
106   const NumberTypePtr num = NumberType::get();
107 
108   ASSERT_TRUE(num->isSubtypeOf(*union1));
109   ASSERT_TRUE(union1->isSubtypeOf(*num));
110   ASSERT_TRUE(*num == *union1);
111 
112   ASSERT_TRUE(num->isSubtypeOf(*union2));
113   ASSERT_FALSE(union2->isSubtypeOf(*num));
114   ASSERT_FALSE(*num == *union2);
115 }
116 
TEST_F(UnionTypeTest,Subtyping_OptionalType)117 TEST_F(UnionTypeTest, Subtyping_OptionalType) {
118   // Union[int, None]
119   const UnionTypePtr union1 =
120       UnionType::create({IntType::get(), NoneType::get()});
121 
122   // Union[int, str, None]
123   const UnionTypePtr union2 =
124       UnionType::create({IntType::get(), StringType::get(), NoneType::get()});
125 
126   // Union[int, str, List[str]]
127   const UnionTypePtr union3 = UnionType::create(
128       {IntType::get(), StringType::get(), ListType::ofStrings()});
129 
130   ASSERT_TRUE(none->isSubtypeOf(opt1));
131   ASSERT_TRUE(none->isSubtypeOf(union1));
132   ASSERT_TRUE(none->isSubtypeOf(union2));
133   ASSERT_FALSE(none->isSubtypeOf(union3));
134 
135   ASSERT_FALSE(opt1->isSubtypeOf(none));
136   ASSERT_TRUE(opt1->isSubtypeOf(union1));
137   ASSERT_TRUE(opt1->isSubtypeOf(union2));
138   ASSERT_FALSE(opt1->isSubtypeOf(union3));
139 
140   ASSERT_FALSE(union1->isSubtypeOf(none));
141   ASSERT_TRUE(union1->isSubtypeOf(opt1));
142   ASSERT_TRUE(union1->isSubtypeOf(union2));
143   ASSERT_FALSE(union1->isSubtypeOf(union3));
144 
145   ASSERT_FALSE(union2->isSubtypeOf(union1));
146 }
147 
148 } // namespace jit
149 } // namespace torch
150