1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/NamedTensorUtils.h>
5 #include <ATen/TensorNames.h>
6 #include <c10/util/Exception.h>
7 #include <c10/util/irange.h>
8
9 using at::Dimname;
10 using at::DimnameList;
11 using at::Symbol;
12 using at::namedinference::TensorName;
13 using at::namedinference::TensorNames;
14
dimnameFromString(const std::string & str)15 static Dimname dimnameFromString(const std::string& str) {
16 return Dimname::fromSymbol(Symbol::dimname(str));
17 }
18
// A freshly created tensor has no dimension names; after attaching names with
// at::internal_set_names_inplace, has_names() reports true.
TEST(NamedTensorTest, isNamed) {
  auto tensor = at::zeros({3, 2, 5, 7});
  ASSERT_FALSE(tensor.has_names());

  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };
  at::internal_set_names_inplace(tensor, names);
  ASSERT_TRUE(tensor.has_names());
}
35
dimnames_equal(at::DimnameList names,at::DimnameList other)36 static bool dimnames_equal(at::DimnameList names, at::DimnameList other) {
37 if (names.size() != other.size()) {
38 return false;
39 }
40 for (const auto i : c10::irange(names.size())) {
41 const auto& name = names[i];
42 const auto& other_name = other[i];
43 if (name.type() != other_name.type() || name.symbol() != other_name.symbol()) {
44 return false;
45 }
46 }
47 return true;
48 }
49
// Names set via at::internal_set_names_inplace are readable through the
// tensor's named-tensor meta. Clearing that meta on the TensorImpl leaves the
// tensor unnamed.
TEST(NamedTensorTest, attachMetadata) {
  auto tensor = at::zeros({3, 2, 5, 7});
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };

  at::internal_set_names_inplace(tensor, names);

  const auto retrieved_meta = tensor.get_named_tensor_meta();
  // Fail the assertion cleanly rather than dereferencing a null pointer if no
  // meta was attached.
  ASSERT_TRUE(retrieved_meta != nullptr);
  ASSERT_TRUE(dimnames_equal(retrieved_meta->names(), names));

  // Test dropping metadata
  tensor.unsafeGetTensorImpl()->set_named_tensor_meta(nullptr);
  ASSERT_FALSE(tensor.has_names());
}
67
// at::internal_set_names_inplace attaches names when given a list. Passing
// std::nullopt removes them, and afterwards both the named-tensor meta pointer
// and opt_names() are empty.
TEST(NamedTensorTest, internalSetNamesInplace) {
  auto tensor = at::zeros({3, 2, 5, 7});
  std::vector<Dimname> names = {
      dimnameFromString("N"),
      dimnameFromString("C"),
      dimnameFromString("H"),
      dimnameFromString("W")};
  ASSERT_FALSE(tensor.has_names());

  // Attach the names, then read them back through opt_names().
  at::internal_set_names_inplace(tensor, names);
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));

  // Clear the names again.
  at::internal_set_names_inplace(tensor, std::nullopt);
  ASSERT_TRUE(tensor.get_named_tensor_meta() == nullptr);
  ASSERT_TRUE(tensor.opt_names() == std::nullopt);
}
87
// at::empty produces unnamed tensors unless names are passed. With names of
// matching length they are attached; with a rank/name-count mismatch it throws
// c10::Error.
TEST(NamedTensorTest, empty) {
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };

  auto tensor = at::empty({});
  ASSERT_EQ(tensor.opt_names(), std::nullopt);

  tensor = at::empty({1, 2, 3});
  ASSERT_EQ(tensor.opt_names(), std::nullopt);

  tensor = at::empty({1, 2, 3, 4}, names);
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));

  // Three dimensions but four names must be rejected.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::empty({1, 2, 3}, names), c10::Error);
}
107
// dimname_to_position throws c10::Error when the requested name is not on the
// tensor (here an unnamed tensor is used). Otherwise it returns the index of
// that name.
TEST(NamedTensorTest, dimnameToPosition) {
  const auto N = dimnameFromString("N");
  const auto C = dimnameFromString("C");
  const auto H = dimnameFromString("H");
  const auto W = dimnameFromString("W");
  const std::vector<Dimname> nchw_names{N, C, H, W};

  auto tensor = at::empty({1, 1, 1});
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(dimname_to_position(tensor, N), c10::Error);

  tensor = at::empty({1, 1, 1, 1}, nchw_names);
  const auto h_index = dimname_to_position(tensor, H);
  ASSERT_EQ(h_index, 2);
}
122
tensornames_unify_from_right(DimnameList names,DimnameList other_names)123 static std::vector<Dimname> tensornames_unify_from_right(
124 DimnameList names,
125 DimnameList other_names) {
126 auto names_wrapper = at::namedinference::TensorNames(names);
127 auto other_wrapper = at::namedinference::TensorNames(other_names);
128 return names_wrapper.unifyFromRightInplace(other_wrapper).toDimnameVec();
129 }
130
check_unify(DimnameList names,DimnameList other_names,DimnameList expected)131 static void check_unify(
132 DimnameList names,
133 DimnameList other_names,
134 DimnameList expected) {
135 // Check legacy at::unify_from_right
136 const auto result = at::unify_from_right(names, other_names);
137 ASSERT_TRUE(dimnames_equal(result, expected));
138
139 // Check with TensorNames::unifyFromRight.
140 // In the future we'll merge at::unify_from_right and
141 // TensorNames::unifyFromRight, but for now, let's test them both.
142 const auto also_result = tensornames_unify_from_right(names, other_names);
143 ASSERT_TRUE(dimnames_equal(also_result, expected));
144 }
145
check_unify_error(DimnameList names,DimnameList other_names)146 static void check_unify_error(DimnameList names, DimnameList other_names) {
147 // In the future we'll merge at::unify_from_right and
148 // TensorNames::unifyFromRight. For now, test them both.
149 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
150 ASSERT_THROW(at::unify_from_right(names, other_names), c10::Error);
151 // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
152 ASSERT_THROW(tensornames_unify_from_right(names, other_names), c10::Error);
153 }
154
// Right-unification cases: identical lists, lists of different lengths aligned
// from the right, and '*' entries (held in `None`) that take on the name they
// are paired with. This is followed by pairs that must be rejected.
TEST(NamedTensorTest, unifyFromRight) {
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  auto None = dimnameFromString("*");

  check_unify({ N, C, H, W }, { N, C, H, W }, { N, C, H, W });
  check_unify({ W }, { C, H, W }, { C, H, W });
  check_unify({ None, W }, { C, H, W }, { C, H, W });
  check_unify({ None, None, H, None }, { C, None, W }, { None, C, H, W });

  check_unify_error({ W, H }, { W, C });
  check_unify_error({ W, H }, { C, H });
  check_unify_error({ None, H }, { H, None });
  check_unify_error({ H, None, C }, { H });
}
174
// Checks that Tensor::alias() carries the source tensor's names over to the
// alias.
TEST(NamedTensorTest, alias) {
  // tensor.alias is not exposed in Python so we test its name propagation here
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  std::vector<Dimname> names = { N, C };

  // Reuse `names` rather than rebuilding an identical vector inline.
  auto tensor = at::empty({2, 3}, names);
  auto aliased = tensor.alias();
  // Give a clear assertion failure (instead of a bad_optional_access from
  // .value()) if the alias lost its names.
  ASSERT_TRUE(aliased.has_names());
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), aliased.opt_names().value()));
}
185
// While an at::NoNamesGuard is alive, names mode is disabled, and neither
// Tensor::opt_names() nor at::impl::get_opt_names reports names for a named
// tensor. Once the guard's scope ends, names mode is enabled again.
TEST(NamedTensorTest, NoNamesGuard) {
  std::vector<Dimname> names = { dimnameFromString("N"), dimnameFromString("C") };
  auto tensor = at::empty({2, 3}, names);

  ASSERT_TRUE(at::NamesMode::is_enabled());
  {
    at::NoNamesGuard no_names;
    ASSERT_FALSE(at::NamesMode::is_enabled());
    ASSERT_FALSE(tensor.opt_names());
    ASSERT_FALSE(at::impl::get_opt_names(tensor.unsafeGetTensorImpl()));
  }
  // Guard destroyed: names mode should be back on.
  ASSERT_TRUE(at::NamesMode::is_enabled());
}
201
nchw()202 static std::vector<Dimname> nchw() {
203 auto N = dimnameFromString("N");
204 auto C = dimnameFromString("C");
205 auto H = dimnameFromString("H");
206 auto W = dimnameFromString("W");
207 return { N, C, H, W };
208 }
209
// Printing a TensorName shows the quoted name, its index, and the full name
// list it was taken from.
TEST(NamedTensorTest, TensorNamePrint) {
  const auto dims = nchw();
  ASSERT_EQ(
      c10::str(TensorName(dims, 0)),
      "'N' (index 0 of ['N', 'C', 'H', 'W'])");
  ASSERT_EQ(
      c10::str(TensorName(dims, 2)),
      "'H' (index 2 of ['N', 'C', 'H', 'W'])");
}
225
// TensorNames::checkUnique accepts a list of distinct names and throws
// c10::Error when a name appears more than once.
TEST(NamedTensorTest, TensorNamesCheckUnique) {
  const auto nchw_names = nchw();

  // Distinct names: smoke test that this does not throw.
  TensorNames(nchw_names).checkUnique("op_name");

  // 'H' repeated at the end: must throw.
  const std::vector<Dimname> duplicated = {
      nchw_names[0], nchw_names[1], nchw_names[2], nchw_names[2]};
  auto wrapped = TensorNames(duplicated);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(wrapped.checkUnique("op_name"), c10::Error);
}
239