1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17
18 #include "tensorflow/core/framework/dataset_metadata.pb.h"
19 #include "tensorflow/core/framework/function_testlib.h"
20 #include "tensorflow/core/graph/node_builder.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25 namespace grappler {
26 namespace graph_utils {
27 namespace {
28
29 using test::function::NDef;
30
31 constexpr char kOutputShapes[] = "output_shapes";
32 constexpr char kOutputTypes[] = "output_types";
33 constexpr char kToutputTypes[] = "Toutput_types";
34
TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
  // The helper yields the index of the first matching element, or -1 when
  // nothing matches.
  const std::vector<int> values = {1, 2, 3, 4, 5, 6};

  // Builds a "multiple of `divisor`" predicate.
  const auto divisible_by = [](int divisor) {
    return [divisor](int elem) { return elem % divisor == 0; };
  };

  // 3 is the first multiple of 3 and sits at index 2.
  EXPECT_EQ(GetFirstElementIndexWithPredicate(divisible_by(3), values), 2);

  // No element is a multiple of 7.
  EXPECT_EQ(GetFirstElementIndexWithPredicate(divisible_by(7), values), -1);
}
46
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // The new node must be registered in the graph, and its "value" tensor
  // must carry the scalar in bool_val.
  NodeDef* const node = AddScalarConstNode<bool>(true, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));

  const auto& tensor_proto = node->attr().at("value").tensor();
  EXPECT_TRUE(tensor_proto.bool_val(0));
}
54
TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Adds a double scalar constant and checks it is in the graph with the
  // value stored in double_val.
  NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));

  // The stored value is a double, so compare at double precision.
  // EXPECT_FLOAT_EQ would narrow both operands to float first.
  EXPECT_DOUBLE_EQ(double_node->attr().at("value").tensor().double_val(0),
                   3.14);
}
62
TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // A float scalar is stored in the tensor's float_val field.
  constexpr float kValue = 3.14f;
  NodeDef* const node = AddScalarConstNode<float>(kValue, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));

  const auto& tensor_proto = node->attr().at("value").tensor();
  EXPECT_FLOAT_EQ(tensor_proto.float_val(0), kValue);
}
70
TEST(GraphUtilsTest, AddScalarConstNodeInt) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // An int scalar is stored in the tensor's int_val field.
  NodeDef* const node = AddScalarConstNode<int>(42, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));

  const auto& tensor_proto = node->attr().at("value").tensor();
  EXPECT_EQ(tensor_proto.int_val(0), 42);
}
78
TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // An int64 scalar is stored in the tensor's int64_val field.
  NodeDef* const node = AddScalarConstNode<int64_t>(42, &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));

  const auto& tensor_proto = node->attr().at("value").tensor();
  EXPECT_EQ(tensor_proto.int64_val(0), 42);
}
86
TEST(GraphUtilsTest, AddScalarConstNodeString) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // A string scalar is stored in the tensor's string_val field.
  NodeDef* const node = AddScalarConstNode<StringPiece>("hello", &graph);
  EXPECT_TRUE(ContainsGraphNodeWithName(node->name(), *graph.graph()));

  const auto& tensor_proto = node->attr().at("value").tensor();
  EXPECT_EQ(tensor_proto.string_val(0), "hello");
}
94
TEST(GraphUtilsTest, GetScalarConstNodeInt64) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Round-trip: a value written by AddScalarConstNode must be readable back
  // through GetScalarConstNodeValue.
  NodeDef* int64_node = AddScalarConstNode<int64_t>(128, &graph);

  // `result` is initialized, and the status check is fatal, so a failed read
  // stops the test instead of inspecting an indeterminate value below.
  int64_t result = 0;
  TF_ASSERT_OK(GetScalarConstNodeValue<int64_t>(*int64_node, &result));
  EXPECT_EQ(result, 128);
}
103
TEST(GraphUtilsTest, GetScalarConstNodeBool) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Round-trip a bool scalar through AddScalarConstNode and
  // GetScalarConstNodeValue.
  NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);

  // Initialized to the opposite of the expected value, with a fatal status
  // check, so a failed read can neither pass nor read an indeterminate bool.
  bool result = false;
  TF_ASSERT_OK(GetScalarConstNodeValue<bool>(*bool_node, &result));
  EXPECT_TRUE(result);
}
112
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithNonConst) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Reading a scalar from a Placeholder (not a Const) must produce an error.
  NodeDef* const placeholder = AddScalarPlaceholder(DT_INT64, &graph);

  int64_t unused_value;
  const Status status =
      GetScalarConstNodeValue<int64_t>(*placeholder, &unused_value);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Placeholder is not a Const node. Op: Placeholder");
}
123
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithType) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Reading an int64 Const as a bool must report the dtype mismatch.
  NodeDef* const int64_const = AddScalarConstNode<int64_t>(128, &graph);

  bool unused_value;
  const Status status =
      GetScalarConstNodeValue<bool>(*int64_const, &unused_value);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should have type bool but has type: int64");
}
134
TEST(GraphUtilsTest, GetScalarConstNodeErrorWithVector) {
  // Hand-build an int64 Const whose tensor has shape [1] (a one-element
  // vector rather than a scalar). The scalar read must reject it.
  NodeDef node;
  node.set_name("Const");
  node.set_op("Const");

  auto& attrs = *node.mutable_attr();
  attrs["dtype"].set_type(DT_INT64);
  auto* const value_tensor = attrs["value"].mutable_tensor();
  value_tensor->set_dtype(DT_INT64);
  value_tensor->mutable_tensor_shape()->add_dim()->set_size(1);
  value_tensor->add_int64_val(128);

  int64_t unused_value;
  const Status status = GetScalarConstNodeValue<int64_t>(node, &unused_value);
  EXPECT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "Node Const should be a scalar but has shape: [1]");
}
152
TEST(GraphUtilsTest, Compare) {
  GraphDef lhs_def;
  MutableGraphView lhs(&lhs_def);
  GraphDef rhs_def;
  MutableGraphView rhs(&rhs_def);

  // Two empty graphs compare equal.
  EXPECT_TRUE(Compare(lhs_def, rhs_def));

  // Populating only one side makes them differ.
  AddNode("A", "OpA", {}, {}, &lhs);
  AddNode("B", "OpB", {"A"}, {}, &lhs);
  EXPECT_FALSE(Compare(lhs_def, rhs_def));

  // Copying the node list across makes them equal again.
  rhs_def.mutable_node()->CopyFrom(lhs_def.node());
  EXPECT_TRUE(Compare(lhs_def, rhs_def));
}
168
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Re-queries the view's current GraphDef on every call.
  const auto contains_a = [&graph]() {
    return ContainsGraphNodeWithName("A", *graph.graph());
  };

  EXPECT_FALSE(contains_a());

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(contains_a());

  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(contains_a());
}
180
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_FALSE(ContainsGraphFunctionWithName("new_function", library));

  // The uniquified name may differ from the requested base name, so look the
  // function up by the name it was actually given.
  FunctionDef* const added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  const auto& assigned_name = added->signature().name();
  EXPECT_TRUE(ContainsGraphFunctionWithName(assigned_name, library));
}
190
TEST(GraphUtilsTest, ContainsNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Re-queries the view's current GraphDef on every call.
  const auto has_op_a = [&graph]() {
    return ContainsNodeWithOp("OpA", *graph.graph());
  };

  EXPECT_FALSE(has_op_a());

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_TRUE(has_op_a());

  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_FALSE(has_op_a());
}
202
TEST(GraphUtilsTest, FindGraphNodeWithName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Index of node "A" in the current GraphDef; -1 means "not present".
  const auto index_of_a = [&graph]() {
    return FindGraphNodeWithName("A", *graph.graph());
  };

  EXPECT_EQ(index_of_a(), -1);

  AddNode("A", "OpA", {}, {}, &graph);
  EXPECT_NE(index_of_a(), -1);

  TF_EXPECT_OK(graph.DeleteNodes({"A"}));
  EXPECT_EQ(index_of_a(), -1);
}
214
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
  FunctionDefLibrary library;
  EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);

  // Look the function up by the (possibly uniquified) name it received.
  FunctionDef* const added = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, added);

  const auto& assigned_name = added->signature().name();
  EXPECT_NE(FindGraphFunctionWithName(assigned_name, library), -1);
}
224
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Nothing to find in an empty graph.
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);

  // Graph: A(OpA) <- B(OpB), A(OpA) <- A2(OpA). Two nodes share "OpA", and
  // the lookup is expected to report the first one, at index 0.
  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"A"}, {}, &graph);
  EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);

  // Once B is gone its op is no longer found, and A2 is expected at index 1.
  TF_EXPECT_OK(graph.DeleteNodes({"B"}));
  EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
  EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
}
239
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // An empty graph has no node with the requested op. This checks the
  // function under test, not FindGraphNodeWithOp.
  EXPECT_TRUE(FindAllGraphNodesWithOp("OpA", *graph.graph()).empty());

  // Graph: A(OpA) <- B(OpB) <- A2(OpA). Both OpA nodes are reported in
  // index order.
  AddNode("A", "OpA", {}, {}, &graph);
  AddNode("B", "OpB", {"A"}, {}, &graph);
  AddNode("A2", "OpA", {"B"}, {}, &graph);
  std::vector<int> result_indices =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  // Size checks are fatal so the indexed checks never run past the end.
  ASSERT_EQ(result_indices.size(), 2u);
  EXPECT_EQ(result_indices[0], 0);
  EXPECT_EQ(result_indices[1], 2);

  // After removing A2, only A remains with op OpA.
  TF_EXPECT_OK(graph.DeleteNodes({"A2"}));
  std::vector<int> result_indices_new =
      FindAllGraphNodesWithOp("OpA", *graph.graph());
  ASSERT_EQ(result_indices_new.size(), 1u);
  EXPECT_EQ(result_indices_new[0], 0);
}
260
TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  // Nodes are added with an empty name. The names they end up with are
  // expected to be pairwise distinct.
  NodeDef* const first = AddNode("", "A", {}, {}, &graph);
  NodeDef* const second = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(first->name(), second->name());

  // Even after a deletion frees a name, a new node must not collide with the
  // surviving one.
  TF_EXPECT_OK(graph.DeleteNodes({first->name()}));
  NodeDef* const third = AddNode("", "A", {}, {}, &graph);
  EXPECT_NE(second->name(), third->name());
}
273
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
  FunctionDefLibrary library;

  // Requesting the same base name twice must yield two distinct names.
  FunctionDef* const first = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, first);

  FunctionDef* const second = library.add_function();
  SetUniqueGraphFunctionName("new_function", &library, second);

  EXPECT_NE(first->signature().name(), second->signature().name());
}
284
TEST(GraphUtilsTest, GetInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const producer = AddNode("", "A", {}, {}, &graph);
  NodeDef* const consumer = AddNode("", "A", {producer->name()}, {}, &graph);

  // The consumer's input resolves to the producer. A node without inputs
  // yields nullptr.
  EXPECT_EQ(GetInputNode(*consumer, graph), producer);
  EXPECT_EQ(GetInputNode(*producer, graph), nullptr);
}
295
TEST(GraphUtilsTest, GetIthInputNode) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* const lhs = AddNode("", "A", {}, {}, &graph);
  NodeDef* const rhs = AddNode("", "A", {}, {}, &graph);
  NodeDef* const consumer =
      AddNode("", "A", {lhs->name(), rhs->name()}, {}, &graph);

  // With no explicit index the first input is returned. Explicit indices
  // select the corresponding input.
  EXPECT_EQ(GetInputNode(*consumer, graph), lhs);
  EXPECT_EQ(GetInputNode(*consumer, graph, 1), rhs);
  EXPECT_EQ(GetInputNode(*consumer, graph, 0), lhs);

  // An out-of-range index, or a node with no inputs, yields nullptr.
  EXPECT_EQ(GetInputNode(*consumer, graph, 2), nullptr);
  EXPECT_EQ(GetInputNode(*lhs, graph), nullptr);
}
310
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
  Graph g(OpRegistry::Global());

  // Initialized so they never hold garbage. Filled in by NodeBuilder below.
  Node* const_0 = nullptr;
  Node* const_1 = nullptr;
  Node* const_2 = nullptr;

  // Arbitrary const
  Tensor tensor(DT_INT32, {});
  tensor.scalar<int32>()() = 5;

  // Two nodes deliberately share the name "Const". The status checks are
  // fatal so a failed Finalize stops the test before the node pointers are
  // dereferenced below.
  for (Node** node : {&const_0, &const_1}) {
    TF_ASSERT_OK(NodeBuilder("Const", "Const")
                     .Attr("value", tensor)
                     .Attr("dtype", DT_INT32)
                     .Finalize(&g, node));
  }
  // Make sure generated name doesn't clash with existing name either
  TF_ASSERT_OK(NodeBuilder("Const_1", "Const")
                   .Attr("value", tensor)
                   .Attr("dtype", DT_INT32)
                   .Finalize(&g, &const_2));

  TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
  EXPECT_NE(const_0->name(), const_1->name());
  EXPECT_NE(const_1->name(), const_2->name());
  EXPECT_NE(const_0->name(), const_2->name());
}
337
TEST(GraphUtilsTest, TestGetFetchNode) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Chain node1 -> node2 -> node3, with node3 as the single fetch.
  NodeDef* node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* node2 = AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  NodeDef* node3 = AddNode("node3", "Identity", {node2->name()}, {}, &graph);
  item.fetch.push_back(node3->name());

  // Initialized, with fatal checks, so the dereference below is only
  // reached when GetFetchNode succeeded and produced a node.
  NodeDef* sink_node = nullptr;
  TF_ASSERT_OK(GetFetchNode(graph, item, &sink_node));
  ASSERT_NE(sink_node, nullptr);
  EXPECT_EQ(sink_node->name(), node3->name());
}
351
TEST(GraphUtilsTest, TestFindSinkNodeMultipleFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Chain node1 -> node2 -> node3.
  NodeDef* const node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* const node2 =
      AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  NodeDef* const node3 =
      AddNode("node3", "Identity", {node2->name()}, {}, &graph);

  // Two fetches: GetFetchNode is expected to reject the item.
  item.fetch.push_back(node2->name());
  item.fetch.push_back(node3->name());

  NodeDef* sink_node = nullptr;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
366
TEST(GraphUtilsTest, TestFindSinkNodeNoFetches) {
  GrapplerItem item;
  MutableGraphView graph(&item.graph);

  // Chain node1 -> node2 -> node3, but item.fetch is left empty.
  NodeDef* const node1 = AddNode("node1", "Identity", {}, {}, &graph);
  NodeDef* const node2 =
      AddNode("node2", "Identity", {node1->name()}, {}, &graph);
  AddNode("node3", "Identity", {node2->name()}, {}, &graph);

  // With no fetch, GetFetchNode is expected to report an error.
  NodeDef* sink_node = nullptr;
  const Status status = GetFetchNode(graph, item, &sink_node);
  EXPECT_FALSE(status.ok());
}
379
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsNoShapes) {
  // The source only carries output_types. Without output_shapes the copy
  // returns false.
  const NodeDef source =
      NDef("range", "RangeDataset", {},
           {{kOutputTypes, gtl::ArraySlice<DataType>{}}});
  NodeDef destination;
  EXPECT_FALSE(CopyShapesAndTypesAttrs(source, &destination));
}
386
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsNoTypes) {
  // The source only carries output_shapes. Without a types attr the copy
  // returns false.
  const NodeDef source =
      NDef("range", "RangeDataset", {},
           {{kOutputShapes, gtl::ArraySlice<TensorShape>{}}});
  NodeDef destination;
  EXPECT_FALSE(CopyShapesAndTypesAttrs(source, &destination));
}
393
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsOutputTypes) {
  // Integer placeholders stand in for real shape/type lists. The test only
  // checks that both attrs are carried over unchanged.
  NodeDef from = NDef("range", "RangeDataset", {},
                      {{kOutputShapes, 666}, {kOutputTypes, 888}});
  NodeDef to_node;
  // Fatal: if the copy fails, the attr().at() lookups below would hit
  // missing keys and crash rather than report a test failure.
  ASSERT_TRUE(CopyShapesAndTypesAttrs(from, &to_node));
  EXPECT_EQ(to_node.attr().at(kOutputShapes).i(), 666);
  EXPECT_EQ(to_node.attr().at(kOutputTypes).i(), 888);
}
402
TEST(GraphUtilsTest, TestCopyShapesAndTypesAttrsToutputTypes) {
  // A source using "Toutput_types" (as TensorDataset does here) should have
  // that value copied to the destination's "output_types" attr.
  NodeDef from = NDef("tensor", "TensorDataset", {},
                      {{kOutputShapes, 666}, {kToutputTypes, 888}});
  NodeDef to_node;
  // Fatal: a failed copy would otherwise crash in the attr().at() lookups.
  ASSERT_TRUE(CopyShapesAndTypesAttrs(from, &to_node));
  EXPECT_EQ(to_node.attr().at(kOutputShapes).i(), 666);
  EXPECT_EQ(to_node.attr().at(kOutputTypes).i(), 888);
}
411
TEST(GraphUtilsTest, TestSetMetadataName) {
  NodeDef node = NDef("range", "RangeDataset", {},
                      {{kOutputShapes, 666}, {kOutputTypes, 888}});

  // The first call attaches a serialized data::Metadata in the "metadata"
  // attr. These checks are fatal because the code below reads that attr and
  // relies on it parsing.
  TF_ASSERT_OK(SetMetadataName("metadata_name", &node));
  ASSERT_TRUE(node.attr().contains("metadata"));

  data::Metadata metadata;
  // ParseFromString reports malformed input through its return value.
  // Ignoring it would let a corrupt proto slip through as an empty name.
  ASSERT_TRUE(metadata.ParseFromString(node.attr().at("metadata").s()));
  EXPECT_EQ("metadata_name", metadata.name());

  // A second call on a node that already carries metadata is rejected.
  EXPECT_FALSE(SetMetadataName("new_metadata_name", &node).ok());
}
422
423 } // namespace
424 } // namespace graph_utils
425 } // namespace grappler
426 } // namespace tensorflow
427