xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/NamedTensor_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/NamedTensorUtils.h>
5 #include <ATen/TensorNames.h>
6 #include <c10/util/Exception.h>
7 #include <c10/util/irange.h>
8 
9 using at::Dimname;
10 using at::DimnameList;
11 using at::Symbol;
12 using at::namedinference::TensorName;
13 using at::namedinference::TensorNames;
14 
dimnameFromString(const std::string & str)15 static Dimname dimnameFromString(const std::string& str) {
16   return Dimname::fromSymbol(Symbol::dimname(str));
17 }
18 
// has_names() must be false for a freshly created tensor and become true once
// dimension names are attached with at::internal_set_names_inplace.
// (The original repeated the identical unnamed-tensor setup/check three times;
// one instance covers the same behavior.)
TEST(NamedTensorTest, isNamed) {
  auto tensor = at::zeros({3, 2, 5, 7});
  ASSERT_FALSE(tensor.has_names());

  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };
  at::internal_set_names_inplace(tensor, names);
  ASSERT_TRUE(tensor.has_names());
}
35 
dimnames_equal(at::DimnameList names,at::DimnameList other)36 static bool dimnames_equal(at::DimnameList names, at::DimnameList other) {
37   if (names.size() != other.size()) {
38     return false;
39   }
40   for (const auto i : c10::irange(names.size())) {
41     const auto& name = names[i];
42     const auto& other_name = other[i];
43     if (name.type() != other_name.type() || name.symbol() != other_name.symbol()) {
44       return false;
45     }
46   }
47   return true;
48 }
49 
// Attaches names with at::internal_set_names_inplace, reads them back through
// the tensor's named-tensor metadata, then clears the metadata directly on the
// TensorImpl and checks that the tensor no longer reports names.
TEST(NamedTensorTest, attachMetadata) {
  auto tensor = at::zeros({3, 2, 5, 7});
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };

  at::internal_set_names_inplace(tensor, names);

  const auto retrieved_meta = tensor.get_named_tensor_meta();
  // Fail the test cleanly rather than dereferencing a null pointer if no
  // metadata ended up attached.
  ASSERT_TRUE(retrieved_meta != nullptr);
  ASSERT_TRUE(dimnames_equal(retrieved_meta->names(), names));

  // Test dropping metadata
  tensor.unsafeGetTensorImpl()->set_named_tensor_meta(nullptr);
  ASSERT_FALSE(tensor.has_names());
}
67 
// Round-trips names through at::internal_set_names_inplace: attaching a list
// makes opt_names() report exactly that list, and passing std::nullopt removes
// the named-tensor metadata again.
TEST(NamedTensorTest, internalSetNamesInplace) {
  auto tensor = at::zeros({3, 2, 5, 7});
  std::vector<Dimname> names = {
      dimnameFromString("N"),
      dimnameFromString("C"),
      dimnameFromString("H"),
      dimnameFromString("W")};
  ASSERT_FALSE(tensor.has_names());

  // Attach, then read the names back.
  at::internal_set_names_inplace(tensor, names);
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));

  // Clearing with std::nullopt must leave no metadata behind.
  at::internal_set_names_inplace(tensor, std::nullopt);
  ASSERT_TRUE(tensor.get_named_tensor_meta() == nullptr);
  ASSERT_TRUE(tensor.opt_names() == std::nullopt);
}
87 
// Checks name handling in the at::empty factory: overloads without names
// produce unnamed tensors, the names overload attaches the given names, and a
// names list whose length differs from the number of dimensions throws.
TEST(NamedTensorTest, empty) {
  // Use the file's shared helper instead of spelling out
  // Dimname::fromSymbol(Symbol::dimname(...)) by hand.
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  std::vector<Dimname> names = { N, C, H, W };

  auto tensor = at::empty({});
  ASSERT_EQ(tensor.opt_names(), std::nullopt);

  tensor = at::empty({1, 2, 3});
  ASSERT_EQ(tensor.opt_names(), std::nullopt);

  tensor = at::empty({1, 2, 3, 4}, names);
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));

  // Three dimensions but four names.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::empty({1, 2, 3}, names), c10::Error);
}
107 
// dimname_to_position must throw for a tensor that carries no names and, for
// a named tensor, return the index at which the requested name appears.
TEST(NamedTensorTest, dimnameToPosition) {
  std::vector<Dimname> names = {
      dimnameFromString("N"),
      dimnameFromString("C"),
      dimnameFromString("H"),
      dimnameFromString("W")};
  const Dimname N = names[0];
  const Dimname H = names[2];

  auto unnamed = at::empty({1, 1, 1});
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::dimname_to_position(unnamed, N), c10::Error);

  auto named = at::empty({1, 1, 1, 1}, names);
  ASSERT_EQ(at::dimname_to_position(named, H), 2);
}
122 
tensornames_unify_from_right(DimnameList names,DimnameList other_names)123 static std::vector<Dimname> tensornames_unify_from_right(
124     DimnameList names,
125     DimnameList other_names) {
126   auto names_wrapper = at::namedinference::TensorNames(names);
127   auto other_wrapper = at::namedinference::TensorNames(other_names);
128   return names_wrapper.unifyFromRightInplace(other_wrapper).toDimnameVec();
129 }
130 
check_unify(DimnameList names,DimnameList other_names,DimnameList expected)131 static void check_unify(
132     DimnameList names,
133     DimnameList other_names,
134     DimnameList expected) {
135   // Check legacy at::unify_from_right
136   const auto result = at::unify_from_right(names, other_names);
137   ASSERT_TRUE(dimnames_equal(result, expected));
138 
139   // Check with TensorNames::unifyFromRight.
140   // In the future we'll merge at::unify_from_right and
141   // TensorNames::unifyFromRight, but for now, let's test them both.
142   const auto also_result = tensornames_unify_from_right(names, other_names);
143   ASSERT_TRUE(dimnames_equal(also_result, expected));
144 }
145 
check_unify_error(DimnameList names,DimnameList other_names)146 static void check_unify_error(DimnameList names, DimnameList other_names) {
147   // In the future we'll merge at::unify_from_right and
148   // TensorNames::unifyFromRight. For now, test them both.
149   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
150   ASSERT_THROW(at::unify_from_right(names, other_names), c10::Error);
151   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
152   ASSERT_THROW(tensornames_unify_from_right(names, other_names), c10::Error);
153 }
154 
// Table of right-unification cases. Successful cases: identical lists, a
// shorter list right-aligned against a longer one, and the wildcard "*"
// taking the name from the other side. Error cases: name pairs that the
// unification must reject (conflicting names at aligned positions, or a
// result that would repeat a name).
// (Removed an unused local `names` vector that the original declared.)
TEST(NamedTensorTest, unifyFromRight) {
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  auto H = dimnameFromString("H");
  auto W = dimnameFromString("W");
  auto None = dimnameFromString("*");

  check_unify({ N, C, H, W }, { N, C, H, W }, { N, C, H, W });
  check_unify({ W }, { C, H, W }, { C, H, W });
  check_unify({ None, W }, { C, H, W }, { C, H, W });
  check_unify({ None, None, H, None }, { C, None, W }, { None, C, H, W });

  check_unify_error({ W, H }, { W, C });
  check_unify_error({ W, H }, { C, H });
  check_unify_error({ None, H }, { H, None });
  check_unify_error({ H, None, C }, { H });
}
174 
// tensor.alias is not exposed in Python so we test its name propagation here.
// The alias must carry names, and they must match both the names the source
// tensor was created with and the source tensor's own names.
TEST(NamedTensorTest, alias) {
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  std::vector<Dimname> names = { N, C };

  // Reuse `names` instead of rebuilding an identical temporary vector.
  auto tensor = at::empty({2, 3}, names);
  auto aliased = tensor.alias();

  // Guard .value() below so a missing name propagation fails as an assertion
  // instead of throwing std::bad_optional_access.
  ASSERT_TRUE(aliased.has_names());
  ASSERT_TRUE(dimnames_equal(aliased.opt_names().value(), names));
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), aliased.opt_names().value()));
}
185 
// While an at::NoNamesGuard is alive, NamesMode is disabled and names are not
// observable through either Tensor::opt_names or at::impl::get_opt_names.
// Once the guard goes out of scope, NamesMode is re-enabled and the tensor's
// original names must be observable again.
TEST(NamedTensorTest, NoNamesGuard) {
  auto N = dimnameFromString("N");
  auto C = dimnameFromString("C");
  std::vector<Dimname> names = { N, C };

  auto tensor = at::empty({2, 3}, names);
  ASSERT_TRUE(at::NamesMode::is_enabled());
  {
    at::NoNamesGuard guard;
    ASSERT_FALSE(at::NamesMode::is_enabled());
    ASSERT_FALSE(tensor.opt_names());
    ASSERT_FALSE(at::impl::get_opt_names(tensor.unsafeGetTensorImpl()));
  }
  ASSERT_TRUE(at::NamesMode::is_enabled());

  // The guard should only have hidden the names, not removed them.
  ASSERT_TRUE(tensor.opt_names().has_value());
  ASSERT_TRUE(dimnames_equal(tensor.opt_names().value(), names));
}
201 
nchw()202 static std::vector<Dimname> nchw() {
203   auto N = dimnameFromString("N");
204   auto C = dimnameFromString("C");
205   auto H = dimnameFromString("H");
206   auto W = dimnameFromString("W");
207   return { N, C, H, W };
208 }
209 
// Printing a TensorName must show the quoted name, its index, and the full
// list of names it was taken from.
TEST(NamedTensorTest, TensorNamePrint) {
  auto names = nchw();

  auto first = TensorName(names, 0);
  ASSERT_EQ(c10::str(first), "'N' (index 0 of ['N', 'C', 'H', 'W'])");

  auto third = TensorName(names, 2);
  ASSERT_EQ(c10::str(third), "'H' (index 2 of ['N', 'C', 'H', 'W'])");
}
225 
// TensorNames::checkUnique must accept a list of distinct names and throw
// c10::Error when a name occurs more than once.
TEST(NamedTensorTest, TensorNamesCheckUnique) {
  auto names = nchw();

  // All-distinct names: smoke test that this does not throw.
  TensorNames(names).checkUnique("op_name");

  // N, C, H, H: the repeated H must be rejected.
  std::vector<Dimname> repeated = { names[0], names[1], names[2], names[2] };
  auto with_repeat = TensorNames(repeated);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(with_repeat.checkUnique("op_name"), c10::Error);
}
239