xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17 
18 #include "tensorflow/core/framework/dataset_metadata.pb.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/graph/node_builder.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_utils {
27 namespace {
28 
29 using test::function::NDef;
30 
31 constexpr char kOutputShapes[] = "output_shapes";
32 constexpr char kOutputTypes[] = "output_types";
33 constexpr char kToutputTypes[] = "Toutput_types";
34 
// GetFirstElementIndexWithPredicate yields the index of the first element that
// satisfies the predicate, or -1 when no element does.
TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
  const std::vector<int> values = {1, 2, 3, 4, 5, 6};

  // 3 is the first multiple of three and sits at index 2.
  auto divisible_by_three = [](int value) { return value % 3 == 0; };
  auto index = GetFirstElementIndexWithPredicate(divisible_by_three, values);
  EXPECT_EQ(index, 2);

  // No element is a multiple of seven.
  auto divisible_by_seven = [](int value) { return value % 7 == 0; };
  index = GetFirstElementIndexWithPredicate(divisible_by_seven, values);
  EXPECT_EQ(index, -1);
}
46 
// Adding a bool scalar constant inserts a named node whose "value" tensor
// carries the bool in bool_val.
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* node = AddScalarConstNode<bool>(true, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));
  const auto& stored = node->attr().at("value").tensor();
  EXPECT_EQ(stored.bool_val(0), true);
}
54 
// Adding a double scalar constant inserts a named node whose "value" tensor
// carries the double in double_val.
TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
  // The stored value is a double, so compare at double precision.
  // EXPECT_FLOAT_EQ would narrow both operands to float before comparing.
  EXPECT_DOUBLE_EQ(double_node->attr().at("value").tensor().double_val(0),
                   3.14);
}
62 
// Adding a float scalar constant inserts a named node whose "value" tensor
// carries the float in float_val.
TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  // The f suffix makes the double-to-float conversion explicit at the call.
  constexpr float kValue = 3.14f;
  NodeDef* node = AddScalarConstNode<float>(kValue, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));
  const auto& stored = node->attr().at("value").tensor();
  EXPECT_FLOAT_EQ(stored.float_val(0), kValue);
}
70 
// Adding an int scalar constant inserts a named node whose "value" tensor
// carries the int in int_val.
TEST(GraphUtilsTest, AddScalarConstNodeInt) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  constexpr int kValue = 42;
  NodeDef* node = AddScalarConstNode<int>(kValue, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));
  const auto& stored = node->attr().at("value").tensor();
  EXPECT_EQ(stored.int_val(0), kValue);
}
78 
// Adding an int64 scalar constant inserts a named node whose "value" tensor
// carries the value in int64_val.
TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  constexpr int64_t kValue = 42;
  NodeDef* node = AddScalarConstNode<int64_t>(kValue, &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));
  const auto& stored = node->attr().at("value").tensor();
  EXPECT_EQ(stored.int64_val(0), kValue);
}
86 
// Adding a string scalar constant inserts a named node whose "value" tensor
// carries the text in string_val.
TEST(GraphUtilsTest, AddScalarConstNodeString) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* node = AddScalarConstNode<StringPiece>("hello", &graph);

  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));
  const auto& stored = node->attr().at("value").tensor();
  EXPECT_EQ(stored.string_val(0), "hello");
}
94 
// GetScalarConstNodeValue reads back the int64 stored by AddScalarConstNode.
TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* int64_node = AddScalarConstNode<int64_t>(128, &graph);
  // Initialize the output and stop on failure. Otherwise a failed lookup would
  // leave `result` indeterminate and the check below would read it.
  int64_t result = 0;
  TF_ASSERT_OK(GetScalarConstNodeValue<int64_t>(*int64_node, &result));
  EXPECT_EQ(result, 128);
}
103 
// GetScalarConstNodeValue reads back the bool stored by AddScalarConstNode.
TEST(GraphUtilsTest, GetScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
  // Initialize the output and stop on failure. Otherwise a failed lookup would
  // leave `result` indeterminate and the check below would read it.
  bool result = false;
  TF_ASSERT_OK(GetScalarConstNodeValue<bool>(*bool_node, &result));
  EXPECT_TRUE(result);
}
112 
// A node whose op is not Const (here a Placeholder) is rejected with an error
// naming the node and its op.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* placeholder = AddScalarPlaceholder(DT_INT64, &graph);

  int64_t unused_result;
  const Status status =
      GetScalarConstNodeValue<int64_t>(*placeholder, &unused_result);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Placeholder is not a Const node. Op: Placeholder");
}
123 
// Requesting a bool from an int64 Const node reports the dtype mismatch.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  NodeDef* const_node = AddScalarConstNode<int64_t>(128, &graph);

  bool unused_result;
  const Status status =
      GetScalarConstNodeValue<bool>(*const_node, &unused_result);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should have type bool but has type: int64");
}
134 
// A Const node whose value has shape [1] (rank 1) is rejected as non-scalar.
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
  // Hand-build the Const node so its tensor shape can be set explicitly.
  NodeDef node;
  node.set_name("Const");
  node.set_op("Const");

  auto& attrs = *node.mutable_attr();
  attrs["dtype"].set_type(DT_INT64);
  auto* value = attrs["value"].mutable_tensor();
  value->set_dtype(DT_INT64);
  value->mutable_tensor_shape()->add_dim()->set_size(1);
  value->add_int64_val(128);

  int64_t unused_result;
  const Status status = GetScalarConstNodeValue<int64_t>(node, &unused_result);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should be a scalar but has shape: [1]");
}
152 
// Compare treats two empty graphs as equal, an empty graph and a populated
// graph as different, and a node-for-node copy as equal again.
TEST(GraphUtilsTest, Compare) {
  GraphDef lhs_def;
  MutableGraphView lhs(&lhs_def);
  GraphDef rhs_def;
  MutableGraphView rhs(&rhs_def);

  EXPECT_TRUE(Compare(lhs_def, rhs_def));

  // Populate only the left-hand graph.
  AddNode("A", "OpA", {}, {}, &lhs);
  AddNode("B", "OpB", {"A"}, {}, &lhs);
  EXPECT_FALSE(Compare(lhs_def, rhs_def));

  // Copying the node list across makes the graphs compare equal.
  rhs_def.mutable_node()->CopyFrom(lhs_def.node());
  EXPECT_TRUE(Compare(lhs_def, rhs_def));
}
168 
// ContainsGraphNodeWithName tracks a node through addition and deletion.
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  // EXPECT_FALSE reports the failing expression directly, unlike
  // EXPECT_TRUE(!...).
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));

  // TF_EXPECT_OK prints the Status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsGraphNodeWithName("A", *graph.graph()));
}
180 
// A library function can be found by name once it has been assigned a unique
// name within that library.
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library));

  FunctionDef* function = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, function);
  const auto& assigned_name = function->signature().name();
  EXPECT_TRUE(ContainsGraphFunctionWithName(assigned_name, library));
}
190 
// ContainsNodeWithOp tracks an op type through node addition and deletion.
TEST(GraphUtilsTest, ContainsNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  // EXPECT_FALSE reports the failing expression directly, unlike
  // EXPECT_TRUE(!...).
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));

  // TF_EXPECT_OK prints the Status message if the deletion fails.
  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(ContainsNodeWithOp("OpA", *graph.graph()));
}
202 
// FindGraphNodeWithName returns -1 for an absent name and some other index
// while the node exists.
TEST(GraphUtilsTest, FindGraphNodeWithName) {
  constexpr int kNotFound = -1;
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  const GraphDef* const nodes = graph.graph();

  EXPECT_EQ(FindGraphNodeWithName("A", *nodes), kNotFound);

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_NE(FindGraphNodeWithName("A", *nodes), kNotFound);

  const Status deletion = graph.DeleteNodes({"A"});
  EXPECT_TRUE(deletion.ok());
  EXPECT_EQ(FindGraphNodeWithName("A", *nodes), kNotFound);
}
214 
// FindGraphFunctionWithName returns -1 for an unknown function and some other
// index once the function has been named in the library.
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
  constexpr int kNotFound = -1;
  FunctionDefLibrary library;
  EXPECT_EQ(FindGraphFunctionWithName("new_function", library), kNotFound);

  FunctionDef* function = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, function);
  const auto& assigned_name = function->signature().name();
  EXPECT_NE(FindGraphFunctionWithName(assigned_name, library), kNotFound);
}
224 
// FindGraphNodeWithOp returns the index of the first node with the given op,
// or -1 when no node has that op.
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  const GraphDef* const nodes = graph.graph();
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *nodes), -1);

  // Two "OpA" nodes; the earlier one, "A", is at index 0.
  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"A"}, {}, &graph);
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *nodes), 0);

  // After "B" is removed, no "OpB" node remains and "A2" sits at index 1.
  EXPECT_TRUE(graph.DeleteNodes({"B"}).ok());
  EXPECT_EQ(FindGraphNodeWithOp("OpB", *nodes), -1);
  EXPECT_EQ(FindGraphNodeWithName("A2", *nodes), 1);
}
239 
// FindAllGraphNodesWithOp returns the indices of every node with the given op,
// in graph order.
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);
  // Run the function under test on the empty graph: it finds nothing.
  EXPECT_TRUE(FindAllGraphNodesWithOp("OpA", *graph.graph()).empty());

  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"B"}, {}, &graph);
  std::vector<int> result_indices =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  // ASSERT on the size so the element checks never index out of range.
  ASSERT_EQ(result_indices.size(), 2u);
  EXPECT_EQ(result_indices[0], 0);
  EXPECT_EQ(result_indices[1], 2);

  TF_EXPECT_OK(graph.DeleteNodes({"A2"}));
  std::vector<int> result_indices_new =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  ASSERT_EQ(result_indices_new.size(), 1u);
  EXPECT_EQ(result_indices_new[0], 0);
}
260 
// Nodes added with an empty name get distinct generated names, including
// after an earlier generated node has been deleted.
TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* first = AddNode("", "A", {}, {}, &graph);
  NodeDef* second = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(first->name(), second->name());

  // Copy the name before deletion; `first` is not touched afterwards.
  const std::string first_name = first->name();
  EXPECT_TRUE(graph.DeleteNodes({first_name}).ok());
  NodeDef* third = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(second->name(), third->name());
}
273 
// Requesting the same base name twice produces two distinct function names.
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
  FunctionDefLibrary library;

  FunctionDef* first = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, first);
  FunctionDef* second = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, second);

  const auto& first_name = first->signature().name();
  const auto& second_name = second->signature().name();
  EXPECT_NE(first_name, second_name);
}
284 
// GetInputNode resolves a node's input to the producing NodeDef and returns
// nullptr for a node with no inputs.
TEST(GraphUtilsTest, GetInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* producer = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer = AddNode("", "A", {producer->name()}, {}, &graph);

  EXPECT_EQ(GetInputNode(*consumer, graph), producer);
  EXPECT_EQ(GetInputNode(*producer, graph), nullptr);
}
295 
// GetInputNode with an explicit index selects that input. An out-of-range
// index, or a node with no inputs, yields nullptr.
TEST(GraphUtilsTest, GetIthInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* lhs = AddNode("", "A", {}, {}, &graph);
  NodeDef* rhs = AddNode("", "A", {}, {}, &graph);
  NodeDef* consumer =
      AddNode("", "A", {lhs->name(), rhs->name()}, {}, &graph);

  // Without an index, the result matches index 0.
  EXPECT_EQ(GetInputNode(*consumer, graph), lhs);
  EXPECT_EQ(GetInputNode(*consumer, graph, 1), rhs);
  EXPECT_EQ(GetInputNode(*consumer, graph, 0), lhs);
  EXPECT_EQ(GetInputNode(*consumer, graph, 2), nullptr);
  EXPECT_EQ(GetInputNode(*lhs, graph), nullptr);
}
310 
// EnsureNodeNamesUnique renames nodes so that no two share a name. This
// includes avoiding names that already exist in the graph.
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
  Graph g(OpRegistry::Global());

  // Initialized, and each Finalize is ASSERTed. A failed build therefore stops
  // the test instead of leaving a garbage pointer to be dereferenced below.
  Node* const_0 = nullptr;
  Node* const_1 = nullptr;
  Node* const_2 = nullptr;

  // Arbitrary scalar int32 value shared by all three nodes.
  Tensor tensor(DT_INT32, {});
  tensor.scalar<int32>()() = 5;

  // Two nodes that both request the name "Const".
  for (Node** node : {&const_0, &const_1}) {
    TF_ASSERT_OK(NodeBuilder("Const", "Const")
                     .Attr("value", tensor)
                     .Attr("dtype", DT_INT32)
                     .Finalize(&g, node));
  }
  // Make sure a generated name doesn't clash with an existing name either.
  TF_ASSERT_OK(NodeBuilder("Const_1", "Const")
                   .Attr("value", tensor)
                   .Attr("dtype", DT_INT32)
                   .Finalize(&g, &const_2));

  TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
  EXPECT_NE(const_0->name(), const_1->name());
  EXPECT_NE(const_1->name(), const_2->name());
  EXPECT_NE(const_0->name(), const_2->name());
}
337 
// With exactly one fetch, GetFetchNode returns the fetched node.
TEST(GraphUtilsTest, TestGetFetchNode) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  NodeDef* node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* node2 = AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  NodeDef* node3 = AddNode("node3", "Identity", {node2->name()}, {}, &graph);
  item.fetch.push_back(node3->name());

  // Initialize the out-param and stop on failure. This keeps the check below
  // from dereferencing an unset pointer.
  NodeDef* sink_node = nullptr;
  TF_ASSERT_OK(GetFetchNode(graph, item, &sink_node));
  ASSERT_NE(sink_node, nullptr);
  EXPECT_EQ(sink_node->name(), node3->name());
}
351 
// GetFetchNode fails when the item lists more than one fetch.
TEST(GraphUtilsTest, TestFindSinkNodeMultipleFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  NodeDef* first = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {first->name()}, {}, &graph);
  NodeDef* last = AddNode("node3", "Identity", {middle->name()}, {}, &graph);
  for (const NodeDef* fetched : {middle, last}) {
    item.fetch.push_back(fetched->name());
  }

  NodeDef* sink_node = nullptr;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
366 
// GetFetchNode fails when the item has no fetches at all.
TEST(GraphUtilsTest, TestFindSinkNodeNoFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  NodeDef* first = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* middle = AddNode("node2", "Identity", {first->name()}, {}, &graph);
  AddNode("node3", "Identity", {middle->name()}, {}, &graph);
  // item.fetch is intentionally left empty.

  NodeDef* sink_node = nullptr;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
379 
// Copying fails when the source node has output types but no output_shapes.
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsNoShapes) {
  NodeDef source = NDef("range", "RangeDataset", {},
                        {{kOutputTypes, gtl::ArraySlice<DataType>{}}});
  NodeDef target;
  const bool copied = CopyShapesAndTypesAttrs(source, &target);
  EXPECT_FALSE(copied);
}
386 
// Copying fails when the source node has output_shapes but no output types.
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsNoTypes) {
  NodeDef source = NDef("range", "RangeDataset", {},
                        {{kOutputShapes, gtl::ArraySlice<TensorShape>{}}});
  NodeDef target;
  const bool copied = CopyShapesAndTypesAttrs(source, &target);
  EXPECT_FALSE(copied);
}
393 
// output_shapes and output_types are copied verbatim. Integer placeholder
// values stand in for real shapes/types, since only the copying is checked.
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsOutputTypes) {
  NodeDef source = NDef("range", "RangeDataset", {},
                        {{kOutputShapes, 666}, {kOutputTypes, 888}});
  NodeDef target;
  EXPECT_TRUE(CopyShapesAndTypesAttrs(source, &target));

  const auto& copied_attrs = target.attr();
  EXPECT_EQ(copied_attrs.at(kOutputShapes).i(), 666);
  EXPECT_EQ(copied_attrs.at(kOutputTypes).i(), 888);
}
402 
// A source that spells its types as Toutput_types has them copied to the
// target under output_types. Integer placeholders stand in for real values.
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsToutputTypes) {
  NodeDef source = NDef("tensor", "TensorDataset", {},
                        {{kOutputShapes, 666}, {kToutputTypes, 888}});
  NodeDef target;
  EXPECT_TRUE(CopyShapesAndTypesAttrs(source, &target));

  const auto& copied_attrs = target.attr();
  EXPECT_EQ(copied_attrs.at(kOutputShapes).i(), 666);
  EXPECT_EQ(copied_attrs.at(kOutputTypes).i(), 888);
}
411 
// SetMetadataName stores a serialized data::Metadata carrying the given name
// in the node's "metadata" attribute. A second call on the same node fails.
TEST(GraphUtilsTest, TestSetMetadataName) {
  NodeDef node = NDef("range", "RangeDataset", {},
                      {{kOutputShapes, 666}, {kOutputTypes, 888}});
  TF_EXPECT_OK(SetMetadataName("metadata_name", &node));
  // ASSERT, not EXPECT: the at() lookup below must not run on a missing key.
  ASSERT_TRUE(node.attr().contains("metadata"));

  data::Metadata metadata;
  // Check the parse result. A malformed payload should fail loudly rather
  // than leave `metadata` default-initialized.
  ASSERT_TRUE(metadata.ParseFromString(node.attr().at("metadata").s()));
  EXPECT_EQ("metadata_name", metadata.name());

  EXPECT_FALSE(SetMetadataName("new_metadata_name", &node).ok());
}
422 
423 }  // namespace
424 }  // namespace graph_utils
425 }  // namespace grappler
426 }  // namespace tensorflow
427