1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
17
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
23 #include "tensorflow/core/grappler/utils.h"
24 #include "tensorflow/core/grappler/utils/graph_view.h"
25 #include "tensorflow/core/grappler/utils/grappler_test.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace grappler {
31
32 class LoopOptimizerTest : public GrapplerTest {
33 protected:
34 // These helpers always set T=DT_FLOAT.
35 void AddEnterNode(const string& name, const string& frame,
36 const bool is_constant, const int piterations,
37 const std::vector<string>& inputs, GraphDef* graph) const {
38 std::vector<std::pair<string, AttrValue>> attributes;
39 AttrValue type;
40 type.set_type(DT_FLOAT);
41 attributes.emplace_back("T", type);
42 AttrValue frame_name;
43 frame_name.set_s(frame);
44 attributes.emplace_back("frame_name", frame_name);
45 AttrValue is_const;
46 is_const.set_b(is_constant);
47 attributes.emplace_back("is_constant", is_const);
48 AttrValue parallel_iterations;
49 parallel_iterations.set_i(piterations);
50 attributes.emplace_back("parallel_iterations", parallel_iterations);
51 AddNode(name, "Enter", inputs, attributes, graph);
52 }
53
54 void AddSimpleNode(const string& name, const string& op,
55 const std::vector<string>& inputs, GraphDef* graph) const {
56 std::vector<std::pair<string, AttrValue>> attributes;
57 AttrValue type;
58 type.set_type(DT_FLOAT);
59 attributes.emplace_back("T", type);
60 AddNode(name, op, inputs, attributes, graph);
61 }
62
63 void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
64 DisableAllStages(optimizer);
65 optimizer->options_.enable_loop_invariant_node_motion = true;
66 }
67
68 void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
69 DisableAllStages(optimizer);
70 optimizer->options_.enable_stack_push_removal = true;
71 }
72
73 private:
74 void DisableAllStages(LoopOptimizer* optimizer) {
75 LoopOptimizer::LoopOptimizerOptions options;
76 options.enable_loop_invariant_node_motion = false;
77 options.enable_stack_push_removal = false;
78 optimizer->options_ = options;
79 }
80 };
81
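// Builds a hand-rolled while loop (Enter -> Merge -> Switch -> Identity ->
// NextIteration, with Exit on the Switch's false output). InvariantAdd only
// depends on the constant Enter, so loop-invariant node motion is expected to
// hoist it out of the frame, while VariantAdd depends on Identity and stays.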
82 TEST_F(LoopOptimizerTest, Basic) {
83 GraphDef graph;
84 AddSimpleNode("In", "Identity", {}, &graph);
85 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
86 &graph);
87 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
88 &graph);
89 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
90 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
91 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
92 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
93 AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
94 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
95 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
96 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
97 AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
98 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
99 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
100
101 GrapplerItem item;
102 item.graph = graph;
103
104 LoopOptimizer optimizer;
105 EnableOnlyLoopInvariantNodeMotion(&optimizer);
106 GraphDef output;
107 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
108
109 { // Original graph.
110 Status status;
111 utils::GraphView view(&graph, &status);
112 TF_ASSERT_OK(status);
113 FrameView frames;
114 TF_EXPECT_OK(frames.InferFromGraphView(view));
115
116 EXPECT_EQ(frames.num_frames(), 1);
117 const auto* invariant_add_node = view.GetNode("InvariantAdd");
118 ASSERT_NE(invariant_add_node, nullptr);
119 const auto* invariant_add_node_def = invariant_add_node->node();
120 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
121 EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
122 const auto* variant_add_node = view.GetNode("VariantAdd");
123 ASSERT_NE(variant_add_node, nullptr);
124 const auto* variant_add_node_def = variant_add_node->node();
125 ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
126 EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
127 }
128
129 { // Optimized graph.
130 Status status;
131 utils::GraphView view(&output, &status);
132 TF_ASSERT_OK(status);
133 FrameView frames;
134 TF_EXPECT_OK(frames.InferFromGraphView(view));
135
136 EXPECT_EQ(frames.num_frames(), 1);
137 const auto* invariant_add_node = view.GetNode("InvariantAdd");
138 ASSERT_NE(invariant_add_node, nullptr);
139 const auto* invariant_add_node_def = invariant_add_node->node();
140 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
141 const auto* variant_add_node = view.GetNode("VariantAdd");
142 ASSERT_NE(variant_add_node, nullptr);
143 const auto* variant_add_node_def = variant_add_node->node();
144 ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
145 EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
146 }
147 }
148
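// Same loop as Basic, but the invariant operand is a Const tied to the loop
// only through a control input. Both Const and InvariantAdd are expected to
// be hoisted out of the frame.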
149 TEST_F(LoopOptimizerTest, Const) {
150 GraphDef graph;
151 AddSimpleNode("In", "Identity", {}, &graph);
152 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
153 &graph);
154 AddSimpleNode("Const", "Const", {"^Identity"}, &graph);
155 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "Const"}, &graph);
156 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
157 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
158 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
159 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
160 AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
161 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
162 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
163 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
164 AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
165 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
166 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
167
168 GrapplerItem item;
169 item.graph = graph;
170
171 LoopOptimizer optimizer;
172 EnableOnlyLoopInvariantNodeMotion(&optimizer);
173 GraphDef output;
174 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
175
176 { // Original graph.
177 Status status;
178 utils::GraphView view(&graph, &status);
179 TF_ASSERT_OK(status);
180 FrameView frames;
181 TF_EXPECT_OK(frames.InferFromGraphView(view));
182
183 EXPECT_EQ(frames.num_frames(), 1);
184 const auto* invariant_add_node = view.GetNode("InvariantAdd");
185 ASSERT_NE(invariant_add_node, nullptr);
186 const auto* invariant_add_node_def = invariant_add_node->node();
187 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
188 EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
189 const auto* const_node = view.GetNode("Const");
190 ASSERT_NE(const_node, nullptr);
191 const auto* const_node_node_def = const_node->node();
192 ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 1);
193 EXPECT_EQ(frames.Frames(*const_node_node_def).back(), 0);
194 }
195
196 { // Optimized graph.
197 Status status;
198 utils::GraphView view(&output, &status);
199 TF_ASSERT_OK(status);
200 FrameView frames;
201 TF_EXPECT_OK(frames.InferFromGraphView(view));
202
203 EXPECT_EQ(frames.num_frames(), 1);
204 const auto* invariant_add_node = view.GetNode("InvariantAdd");
205 ASSERT_NE(invariant_add_node, nullptr);
206 const auto* invariant_add_node_def = invariant_add_node->node();
207 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
208 const auto* const_node = view.GetNode("Const");
209 ASSERT_NE(const_node, nullptr);
210 const auto* const_node_node_def = const_node->node();
211 ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 0);
212 }
213 }
214
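// Like Basic, but InvariantAdd is also a control input of the loop condition
// "Less", so the test expects loop-invariant node motion to leave it inside
// the frame.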
215 TEST_F(LoopOptimizerTest, ControlOutput) {
216 GraphDef graph;
217 AddSimpleNode("In", "Identity", {}, &graph);
218 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
219 &graph);
220 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
221 &graph);
222 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
223 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
224 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
225 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
226 AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y", "^InvariantAdd"},
227 &graph);
228 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
229 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
230 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
231 AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
232 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
233 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
234
235 GrapplerItem item;
236 item.graph = graph;
237
238 LoopOptimizer optimizer;
239 EnableOnlyLoopInvariantNodeMotion(&optimizer);
240 GraphDef output;
241 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
242
243 { // Original graph.
244 Status status;
245 utils::GraphView view(&graph, &status);
246 TF_ASSERT_OK(status);
247 FrameView frames;
248 TF_EXPECT_OK(frames.InferFromGraphView(view));
249
250 EXPECT_EQ(frames.num_frames(), 1);
251 const auto* invariant_add_node = view.GetNode("InvariantAdd");
252 ASSERT_NE(invariant_add_node, nullptr);
253 const auto* invariant_add_node_def = invariant_add_node->node();
254 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
255 EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
256 }
257
258 { // Optimized graph.
259 Status status;
260 utils::GraphView view(&output, &status);
261 TF_ASSERT_OK(status);
262 FrameView frames;
263 TF_EXPECT_OK(frames.InferFromGraphView(view));
264
265 EXPECT_EQ(frames.num_frames(), 1);
266 const auto* invariant_add_node = view.GetNode("InvariantAdd");
267 ASSERT_NE(invariant_add_node, nullptr);
268 const auto* invariant_add_node_def = invariant_add_node->node();
269 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
270 EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
271 }
272 }
273
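// Two nested while loops. InvariantEnter2 is fed by the outer loop's
// VariantAdd, so InvariantAdd2 can be hoisted only one level, from the inner
// frame into the outer frame, while the outer InvariantAdd is hoisted out of
// the loop entirely.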
274 TEST_F(LoopOptimizerTest, NestedLoop1) {
275 GraphDef graph;
276 AddSimpleNode("In", "Identity", {}, &graph);
277 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
278 &graph);
279 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
280 &graph);
281 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
282 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
283 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
284 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
285 AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
286 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
287 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
288 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
289 AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
290 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
291 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
292
293 AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
294 {"VariantAdd"}, &graph);
295 AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
296 &graph);
297 AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
298 AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
299 {"VariantEnter"}, &graph);
300 AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
301 AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
302 AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
303 AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
304 AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
305 AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
306 AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
307 AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
308
309 GrapplerItem item;
310 item.graph = graph;
311
312 LoopOptimizer optimizer;
313 EnableOnlyLoopInvariantNodeMotion(&optimizer);
314 GraphDef output;
315 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
316
317 { // Original graph.
318 Status status;
319 utils::GraphView view(&graph, &status);
320 TF_ASSERT_OK(status);
321 FrameView frames;
322 TF_EXPECT_OK(frames.InferFromGraphView(view));
323
324 EXPECT_EQ(frames.num_frames(), 2);
325 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
326 ASSERT_NE(invariant_add_2_node, nullptr);
327 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
328 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
329 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
330 const auto* variant_add_2_node = view.GetNode("VariantAdd2");
331 ASSERT_NE(variant_add_2_node, nullptr);
332 const auto* variant_add_2_node_def = variant_add_2_node->node();
333 ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
334 EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
335 const auto* invariant_add_node = view.GetNode("InvariantAdd");
336 ASSERT_NE(invariant_add_node, nullptr);
337 const auto* invariant_add_node_def = invariant_add_node->node();
338 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
339 EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
340 }
341
342 { // Optimized graph.
343 Status status;
344 utils::GraphView view(&output, &status);
345 TF_ASSERT_OK(status);
346 FrameView frames;
347 TF_EXPECT_OK(frames.InferFromGraphView(view));
348
349 EXPECT_EQ(frames.num_frames(), 2);
350 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
351 ASSERT_NE(invariant_add_2_node, nullptr);
352 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
353 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
354 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
355 const auto* variant_add_2_node = view.GetNode("VariantAdd2");
356 ASSERT_NE(variant_add_2_node, nullptr);
357 const auto* variant_add_2_node_def = variant_add_2_node->node();
358 ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
359 EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
360 const auto* invariant_add_node = view.GetNode("InvariantAdd");
361 ASSERT_NE(invariant_add_node, nullptr);
362 const auto* invariant_add_node_def = invariant_add_node->node();
363 ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
364 }
365 }
366
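// Like NestedLoop1, but InvariantEnter2 is fed by the outer loop's
// InvariantAdd instead of VariantAdd, so InvariantAdd2 is expected to be
// hoisted out of both frames.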
367 TEST_F(LoopOptimizerTest, NestedLoop2) {
368 GraphDef graph;
369 AddSimpleNode("In", "Identity", {}, &graph);
370 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
371 &graph);
372 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
373 &graph);
374 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
375 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
376 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
377 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
378 AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
379 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
380 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
381 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
382 AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
383 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
384 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
385
386 AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
387 {"InvariantAdd"}, &graph);
388 AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
389 &graph);
390 AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
391 AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
392 {"VariantEnter"}, &graph);
393 AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
394 AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
395 AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
396 AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
397 AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
398 AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
399 AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
400 AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
401
402 GrapplerItem item;
403 item.graph = graph;
404
405 LoopOptimizer optimizer;
406 EnableOnlyLoopInvariantNodeMotion(&optimizer);
407 GraphDef output;
408 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
409
410 { // Original graph.
411 Status status;
412 utils::GraphView view(&graph, &status);
413 TF_ASSERT_OK(status);
414 FrameView frames;
415 TF_EXPECT_OK(frames.InferFromGraphView(view));
416
417 EXPECT_EQ(frames.num_frames(), 2);
418 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
419 ASSERT_NE(invariant_add_2_node, nullptr);
420 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
421 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
422 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
423 const auto* variant_add_2_node = view.GetNode("VariantAdd2");
424 ASSERT_NE(variant_add_2_node, nullptr);
425 const auto* variant_add_2_node_def = variant_add_2_node->node();
426 ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
427 EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
428 }
429
430 { // Optimized graph.
431 Status status;
432 utils::GraphView view(&output, &status);
433 TF_ASSERT_OK(status);
434 FrameView frames;
435 TF_EXPECT_OK(frames.InferFromGraphView(view));
436
437 EXPECT_EQ(frames.num_frames(), 2);
438 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
439 ASSERT_NE(invariant_add_2_node, nullptr);
440 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
441 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
442 const auto* variant_add_2_node = view.GetNode("VariantAdd2");
443 ASSERT_NE(variant_add_2_node, nullptr);
444 const auto* variant_add_2_node_def = variant_add_2_node->node();
445 ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
446 EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
447 }
448 }
449
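// Nested-loop version of the Const test: the inner invariant operand is
// Const2. Since InvariantEnter2 is fed by the outer VariantAdd, Const2 and
// InvariantAdd2 are expected to move only one level out, into the outer
// frame.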
450 TEST_F(LoopOptimizerTest, NestedLoopConst1) {
451 GraphDef graph;
452 AddSimpleNode("In", "Identity", {}, &graph);
453 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
454 &graph);
455 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
456 &graph);
457 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
458 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
459 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
460 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
461 AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
462 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
463 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
464 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
465 AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
466 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
467 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
468
469 AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
470 {"VariantAdd"}, &graph);
471 AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
472 AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
473 AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
474 AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
475 {"VariantEnter"}, &graph);
476 AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
477 AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
478 AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
479 AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
480 AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
481 AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
482 AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
483 AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
484
485 GrapplerItem item;
486 item.graph = graph;
487
488 LoopOptimizer optimizer;
489 EnableOnlyLoopInvariantNodeMotion(&optimizer);
490 GraphDef output;
491 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
492
493 { // Original graph.
494 Status status;
495 utils::GraphView view(&graph, &status);
496 TF_ASSERT_OK(status);
497 FrameView frames;
498 TF_EXPECT_OK(frames.InferFromGraphView(view));
499
500 EXPECT_EQ(frames.num_frames(), 2);
501 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
502 ASSERT_NE(invariant_add_2_node, nullptr);
503 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
504 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
505 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
506 const auto* const_2_node = view.GetNode("Const2");
507 ASSERT_NE(const_2_node, nullptr);
508 const auto* const_2_node_def = const_2_node->node();
509 ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
510 EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
511 }
512
513 { // Optimized graph.
514 Status status;
515 utils::GraphView view(&output, &status);
516 TF_ASSERT_OK(status);
517 FrameView frames;
518 TF_EXPECT_OK(frames.InferFromGraphView(view));
519
520 EXPECT_EQ(frames.num_frames(), 2);
521 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
522 ASSERT_NE(invariant_add_2_node, nullptr);
523 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
524 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
525 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
526 const auto* const_2_node = view.GetNode("Const2");
527 ASSERT_NE(const_2_node, nullptr);
528 const auto* const_2_node_def = const_2_node->node();
529 ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 1);
530 EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 0);
531 }
532 }
533
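// Like NestedLoopConst1, but InvariantEnter2 is fed by the outer
// InvariantAdd, so Const2 and InvariantAdd2 are expected to be hoisted out of
// both frames.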
534 TEST_F(LoopOptimizerTest, NestedLoopConst2) {
535 GraphDef graph;
536 AddSimpleNode("In", "Identity", {}, &graph);
537 AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
538 &graph);
539 AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
540 &graph);
541 AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
542 AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
543 AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
544 AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
545 AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
546 AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
547 AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
548 AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
549 AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
550 AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
551 AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
552
553 AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
554 {"InvariantAdd"}, &graph);
555 AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
556 AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
557 AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
558 AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
559 {"VariantEnter"}, &graph);
560 AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
561 AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
562 AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
563 AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
564 AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
565 AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
566 AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
567 AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
568
569 GrapplerItem item;
570 item.graph = graph;
571
572 LoopOptimizer optimizer;
573 EnableOnlyLoopInvariantNodeMotion(&optimizer);
574 GraphDef output;
575 TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
576
577 { // Original graph.
578 Status status;
579 utils::GraphView view(&graph, &status);
580 TF_ASSERT_OK(status);
581 FrameView frames;
582 TF_EXPECT_OK(frames.InferFromGraphView(view));
583
584 EXPECT_EQ(frames.num_frames(), 2);
585 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
586 ASSERT_NE(invariant_add_2_node, nullptr);
587 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
588 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
589 EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
590 const auto* const_2_node = view.GetNode("Const2");
591 ASSERT_NE(const_2_node, nullptr);
592 const auto* const_2_node_def = const_2_node->node();
593 ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
594 EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
595 }
596
597 { // Optimized graph.
598 Status status;
599 utils::GraphView view(&output, &status);
600 TF_ASSERT_OK(status);
601 FrameView frames;
602 TF_EXPECT_OK(frames.InferFromGraphView(view));
603
604 EXPECT_EQ(frames.num_frames(), 2);
605 const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
606 ASSERT_NE(invariant_add_2_node, nullptr);
607 const auto* invariant_add_2_node_def = invariant_add_2_node->node();
608 ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
609 const auto* const_2_node = view.GetNode("Const2");
610 ASSERT_NE(const_2_node, nullptr);
611 const auto* const_2_node_def = const_2_node->node();
612 ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 0);
613 }
614 }
615
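// Asserts that optimization left the graph unchanged: the same nodes with the
// same ops and inputs, in the same order.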
616 void VerifyGraphsEqual(const GraphDef& original_graph,
617 const GraphDef& optimized_graph, const string& func) {
618 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
619 for (int i = 0; i < original_graph.node_size(); ++i) {
620 const NodeDef& original = original_graph.node(i);
621 const NodeDef& optimized = optimized_graph.node(i);
622 EXPECT_EQ(optimized.name(), original.name()) << func;
623 EXPECT_EQ(optimized.op(), original.op()) << func;
624 ASSERT_EQ(optimized.input_size(), original.input_size()) << func;
625 for (int j = 0; j < original.input_size(); ++j) {
626 EXPECT_EQ(optimized.input(j), original.input(j)) << func;
627 }
628 }
629 }
630
631 TEST_F(LoopOptimizerTest, NoOp) {
632 // This graph is so trivial that there is nothing to optimize.
633 TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
634 GrapplerItem item;
635 CHECK(fake_input.NextItem(&item));
636
637 LoopOptimizer optimizer;
638 EnableOnlyStackPushRemoval(&optimizer);
639 GraphDef output;
640 Status status = optimizer.Optimize(nullptr, item, &output);
641 TF_EXPECT_OK(status);
642
643 VerifyGraphsEqual(item.graph, output, __FUNCTION__);
644 }
645
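// Every stack below either has a matching pop (possibly behind an Enter) or
// an unexpected consumer, so stack push removal should leave the graph
// untouched.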
646 TEST_F(LoopOptimizerTest, RemovePushNoOp) {
647 GrapplerItem item;
648 GraphDef& graph = item.graph;
649 AddSimpleNode("c", "Const", {}, &graph);
650 // Stack with corresponding push/pop.
651 AddSimpleNode("stack1", "StackV2", {}, &graph);
652 AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
653 AddSimpleNode("pop1", "StackPopV2", {"stack1"}, &graph);
654 AddSimpleNode("id1", "Identity", {"pop1"}, &graph);
655 // Stack with corresponding push/pop behind Enter.
656 AddSimpleNode("stack2", "StackV2", {}, &graph);
657 AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
658 AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
659 AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
660 AddSimpleNode("pop2", "StackPopV2", {"enter2_stack2"}, &graph);
661 AddSimpleNode("id2", "Identity", {"pop2"}, &graph);
662 // Stack with unexpected op type in fanout of Stack.
663 AddSimpleNode("stack3", "StackV2", {}, &graph);
664 AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
665 AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
666
667 LoopOptimizer optimizer;
668 EnableOnlyStackPushRemoval(&optimizer);
669 GraphDef output;
670 Status status = optimizer.Optimize(nullptr, item, &output);
671 TF_EXPECT_OK(status);
672 VerifyGraphsEqual(item.graph, output, __FUNCTION__);
673 }
674
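// The pushes below have no matching pops, but both stacks are listed in
// keep_ops, so the pushes must not be removed.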
675 TEST_F(LoopOptimizerTest, RemovePushNoPopButStackLives) {
676 GrapplerItem item;
677 GraphDef& graph = item.graph;
678 AddSimpleNode("c", "Const", {}, &graph);
679 // Stack with a corresponding push but no pop.
680 AddSimpleNode("stack1", "StackV2", {}, &graph);
681 AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
682 // Stack with a corresponding push behind Enter but no pop.
683 AddSimpleNode("stack2", "StackV2", {}, &graph);
684 AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, &graph);
685 AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, &graph);
686 AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &graph);
687 item.keep_ops.push_back("stack1");
688 item.keep_ops.push_back("stack2");
689
690 LoopOptimizer optimizer;
691 EnableOnlyStackPushRemoval(&optimizer);
692 GraphDef output;
693 Status status = optimizer.Optimize(nullptr, item, &output);
694 TF_EXPECT_OK(status);
695 VerifyGraphsEqual(item.graph, output, __FUNCTION__);
696 }
697
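// Pushes whose stack has no pop, or whose pop has no consumer, are expected
// to be rewritten into Identity nodes with a control dependency on the stack;
// the push feeding the fetched "pop4" must be left untouched.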
698 TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
699 GrapplerItem item;
700 GraphDef& graph = item.graph;
701 AddSimpleNode("c", "Const", {}, &graph);
702 // Push without Pop.
703 AddSimpleNode("stack1", "StackV2", {}, &graph);
704 AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
705 // Push without Pop behind Enter.
706 AddSimpleNode("stack2", "StackV2", {}, &graph);
707 AddEnterNode("enter_c", "frame_name", false, 1, {"c"}, &graph);
708 AddEnterNode("enter_stack2", "frame_name", false, 1, {"stack2"}, &graph);
709 AddSimpleNode("push2", "StackPushV2", {"enter_stack2", "enter_c"}, &graph);
710 // Pop without consumer.
711 AddSimpleNode("stack3", "StackV2", {}, &graph);
712 AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
713 AddSimpleNode("pop3", "StackPopV2", {"stack3"}, &graph);
714 // A Push for a Pop that has no consumer but is fetched should not be removed.
715 AddSimpleNode("stack4", "StackV2", {}, &graph);
716 AddSimpleNode("push4", "StackPushV2", {"stack4", "c"}, &graph);
717 AddSimpleNode("pop4", "StackPopV2", {"stack4"}, &graph);
718
719 item.fetch.push_back("pop4");
720
721 LoopOptimizer optimizer;
722 EnableOnlyStackPushRemoval(&optimizer);
723 GraphDef output;
724 Status status = optimizer.Optimize(nullptr, item, &output);
725 TF_EXPECT_OK(status);
726
727 EXPECT_EQ(output.node_size(), 13);
728 for (int i = 0; i < output.node_size(); ++i) {
729 const NodeDef& node = output.node(i);
730 if (node.name() == "push1") {
731 EXPECT_EQ(node.op(), "Identity");
732 ASSERT_EQ(node.input_size(), 2);
733 EXPECT_EQ(node.input(0), "c");
734 EXPECT_EQ(node.input(1), "^stack1");
735 } else if (node.name() == "push2") {
736 EXPECT_EQ(node.op(), "Identity");
737 ASSERT_EQ(node.input_size(), 2);
738 EXPECT_EQ(node.input(0), "enter_c");
739 EXPECT_EQ(node.input(1), "^enter_stack2");
740 } else if (node.name() == "push3") {
741 EXPECT_EQ(node.op(), "Identity");
742 ASSERT_EQ(node.input_size(), 2);
743 EXPECT_EQ(node.input(0), "c");
744 EXPECT_EQ(node.input(1), "^stack3");
745 } else {
746 const NodeDef& orig_node = item.graph.node(i);
747 EXPECT_EQ(node.ShortDebugString(), orig_node.ShortDebugString());
748 }
749 }
750 }
751
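// Switches driven by constant predicates have one dead output branch. The
// test expects dead branches to be pruned, rewritable Switches to become
// Identity nodes with a control dependency on their predicate, and Merges
// with a single live input to collapse into Identity nodes, while Switches
// reaching the fetched nodes (switch5, switch6) stay unchanged and the
// fetched values keep their original results.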
752 TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
753 Scope scope = Scope::NewRootScope();
754 Output v_in = ops::Const<float>(scope.WithOpName("v_in"), {123.0}, {});
755
756 Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({}));
757 ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
758 Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
759 Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true);
760
761 Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
762 ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
763 Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
764 Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
765
766 Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
767 ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
768 Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
769 Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
770
771 Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
772 ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
773 Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
774 Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
775
776 ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
777 ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
778 ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
779 ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
780 ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
781
782 ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
783 Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
784 Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
785 ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
786
787 ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
788 Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
789 Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
790 ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
791
792 GrapplerItem item;
793 item.fetch.push_back("m8");
794 item.fetch.push_back("id4");
795
796 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
797
798 LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
799 GraphDef output;
800 Status status = optimizer.Optimize(nullptr, item, &output);
801 TF_CHECK_OK(status);
802
803 for (const NodeDef& node : output.node()) {
804 // These dead nodes should have been pruned.
805 EXPECT_NE(node.name(), "sqrt1");
806 EXPECT_NE(node.name(), "square2");
807 EXPECT_NE(node.name(), "m5");
808
809 if (node.name() == "m1") {
810 // sqrt1 is dead
811 EXPECT_EQ(node.op(), "Identity");
812 ASSERT_EQ(node.input_size(), 1);
813 EXPECT_EQ(node.input(0), "square1");
814 } else if (node.name() == "m2") {
815 // both inputs are alive
816 EXPECT_EQ(node.op(), "Merge");
817 ASSERT_EQ(node.input_size(), 2);
818 EXPECT_EQ(node.input(0), "v_in");
819 EXPECT_EQ(node.input(1), "square1");
820 } else if (node.name() == "m3") {
821 // sqrt1 is dead
822 EXPECT_EQ(node.op(), "Identity");
823 ASSERT_EQ(node.input_size(), 1);
824 EXPECT_EQ(node.input(0), "v_in");
825 } else if (node.name() == "m4") {
826 // both inputs are alive
827 EXPECT_EQ(node.op(), "Merge");
828 ASSERT_EQ(node.input_size(), 2);
829 EXPECT_EQ(node.input(0), "square1");
830 EXPECT_EQ(node.input(1), "sqrt2");
831 } else if (node.name() == "m8") {
832 // The node is to be preserved because of a fetch
833 EXPECT_EQ(node.op(), "Merge");
834 ASSERT_EQ(node.input_size(), 2);
835 EXPECT_EQ(node.input(0), "id1");
836 EXPECT_EQ(node.input(1), "id2");
837 } else if (node.name() == "m9") {
838 // The node is to be preserved because of a fetch
839 EXPECT_EQ(node.op(), "Merge");
840 ASSERT_EQ(2, node.input_size());
841 EXPECT_EQ(node.input(0), "id3");
842 EXPECT_EQ(node.input(1), "id4");
843 } else if (node.name() == "switch1") {
844 // The node can be replaced by Identity with control_dependency
845 EXPECT_EQ(node.op(), "Identity");
846 ASSERT_EQ(node.input_size(), 2);
847 EXPECT_EQ(node.input(0), "v_in");
848 EXPECT_EQ(node.input(1), "^ctrl1");
849 } else if (node.name() == "switch2") {
850 // The node can be replaced by Identity with control_dependency
851 EXPECT_EQ(node.op(), "Identity");
852 ASSERT_EQ(node.input_size(), 2);
853 EXPECT_EQ(node.input(0), "v_in");
854 EXPECT_EQ(node.input(1), "^ctrl2");
855 } else if (node.name() == "switch3") {
856 // The node can be replaced by Identity with control_dependency
857 EXPECT_EQ(node.op(), "Identity");
858 ASSERT_EQ(node.input_size(), 2);
859 EXPECT_EQ(node.input(0), "v_in");
860 EXPECT_EQ(node.input(1), "^ctrl3");
861 } else if (node.name() == "switch4") {
862 // The node can be replaced by Identity with control_dependency
863 EXPECT_EQ(node.op(), "Identity");
864 ASSERT_EQ(node.input_size(), 2);
865 EXPECT_EQ(node.input(0), "v_in");
866 EXPECT_EQ(node.input(1), "^ctrl4");
867 } else if (node.name() == "switch5") {
868 // The node should remain unchanged
869 EXPECT_EQ(node.op(), "Switch");
870 ASSERT_EQ(node.input_size(), 2);
871 EXPECT_EQ(node.input(0), "v_in");
872 EXPECT_EQ(node.input(1), "ctrl1");
873 } else if (node.name() == "switch6") {
874 // The node should remain unchanged
875 EXPECT_EQ(node.op(), "Switch");
876 ASSERT_EQ(node.input_size(), 2);
877 EXPECT_EQ(node.input(0), "v_in");
878 EXPECT_EQ(node.input(1), "ctrl1");
879 }
880 }
881
882 auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
883 ASSERT_EQ(tensors_expected.size(), 2);
884
885 auto tensors = EvaluateNodes(output, {"m8", "m9"});
886 ASSERT_EQ(tensors.size(), 2);
887
888 test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
889 test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
890 }
891
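// With ctrl1 constant true, the false branch (square1) is dead: m1 should
// collapse into an Identity of sub1 and switch1 should be rewritten into an
// Identity of v_in with a control dependency on ctrl1.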
892 TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition2) {
893 Scope scope = Scope::NewRootScope();
894 Output v_in = ops::Const<float>(scope.WithOpName("v_in"), {123.0}, {});
895
896 Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), true, TensorShape({}));
897 ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
898
899 Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
900
901 Output add1 =
902 ops::Add(scope.WithOpName("add1"), s1.output_true, s1.output_true);
903
904 Output const2 = ops::Const<float>(scope.WithOpName("const2"), {20.0}, {});
905 Output add2 = ops::Add(scope.WithOpName("add2"), s1.output_true, const2);
906
907 Output sub1 = ops::Sub(scope.WithOpName("sub1"), add1, add2);
908
909 ops::Merge m1(scope.WithOpName("m1"), {square1, sub1});
910 Output add3 = ops::Add(scope.WithOpName("add3"), m1.output, const2);
911
912 GrapplerItem item;
913 item.fetch.push_back("add3");
914
915 TF_CHECK_OK(scope.ToGraphDef(&item.graph));
916
917 LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
918 GraphDef output;
919 Status status = optimizer.Optimize(nullptr, item, &output);
920 TF_CHECK_OK(status);
921
922 for (const NodeDef& node : output.node()) {
923 // This dead node should have been pruned.
924 EXPECT_NE(node.name(), "square1");
925
926 if (node.name() == "m1") {
927 // square1 is dead
928 EXPECT_EQ(node.op(), "Identity");
929 ASSERT_EQ(node.input_size(), 1);
930 EXPECT_EQ(node.input(0), "sub1");
931 } else if (node.name() == "switch1") {
932 EXPECT_EQ(node.op(), "Identity");
933 ASSERT_EQ(node.input_size(), 2);
934 EXPECT_EQ(node.input(0), "v_in");
935 EXPECT_EQ(node.input(1), "^ctrl1");
936 }
937 }
938 }
939
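// A pruned-down AssertGuard conditional driven by a constant predicate.
// Several nodes outside the conditional depend on the guard's Merge through
// control edges, so the test verifies that the Merge node survives
// dead-branch removal.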
940 TEST_F(LoopOptimizerTest, RemoveDeadBranchesFullyRemoveDeadBranches) {
941 const string gdef_ascii = R"EOF(
942 node {
943 name: "episodicreplaybuffer_add_readvariableop_resource"
944 op: "_Arg"
945 device: "/job:localhost/replica:0/task:0/device:CPU:0"
946 attr {
947 key: "T"
948 value {
949 type: DT_RESOURCE
950 }
951 }
952 attr {
953 key: "index"
954 value {
955 i: 0
956 }
957 }
958 }
959 node {
960 name: "EpisodicReplayBuffer/add/and_1/x"
961 op: "Const"
962 device: "/job:localhost/replica:0/task:0/device:CPU:0"
963 attr {
964 key: "dtype"
965 value {
966 type: DT_BOOL
967 }
968 }
969 attr {
970 key: "value"
971 value {
972 tensor {
973 dtype: DT_BOOL
974 tensor_shape {
975 }
976 bool_val: true
977 }
978 }
979 }
980 }
981 node {
982 name: "EpisodicReplayBuffer/add/begin_episode"
983 op: "Const"
984 device: "/job:localhost/replica:0/task:0/device:CPU:0"
985 attr {
986 key: "dtype"
987 value {
988 type: DT_BOOL
989 }
990 }
991 attr {
992 key: "value"
993 value {
994 tensor {
995 dtype: DT_BOOL
996 tensor_shape {
997 }
998 bool_val: false
999 }
1000 }
1001 }
1002 }
1003 node {
1004 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Switch"
1005 op: "Switch"
1006 input: "EpisodicReplayBuffer/add/and_1/x"
1007 input: "EpisodicReplayBuffer/add/and_1/x"
1008 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1009 attr {
1010 key: "T"
1011 value {
1012 type: DT_BOOL
1013 }
1014 }
1015 }
1016 node {
1017 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/NoOp"
1018 op: "NoOp"
1019 input: "^EpisodicReplayBuffer/add/and_1/x"
1020 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1021 }
1022 node {
1023 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch"
1024 op: "Switch"
1025 input: "EpisodicReplayBuffer/add/and_1/x"
1026 input: "EpisodicReplayBuffer/add/and_1/x"
1027 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1028 attr {
1029 key: "T"
1030 value {
1031 type: DT_BOOL
1032 }
1033 }
1034 attr {
1035 key: "_class"
1036 value {
1037 list {
1038 s: "loc:@EpisodicReplayBuffer/add/assert_equal/All"
1039 }
1040 }
1041 }
1042 }
1043 node {
1044 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_1"
1045 op: "Switch"
1046 input: "EpisodicReplayBuffer/add/begin_episode"
1047 input: "EpisodicReplayBuffer/add/and_1/x"
1048 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1049 attr {
1050 key: "T"
1051 value {
1052 type: DT_BOOL
1053 }
1054 }
1055 attr {
1056 key: "_class"
1057 value {
1058 list {
1059 s: "loc:@EpisodicReplayBuffer/add/begin_episode"
1060 }
1061 }
1062 }
1063 }
1064 node {
1065 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_2"
1066 op: "Switch"
1067 input: "EpisodicReplayBuffer/add/begin_episode"
1068 input: "EpisodicReplayBuffer/add/and_1/x"
1069 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1070 attr {
1071 key: "T"
1072 value {
1073 type: DT_BOOL
1074 }
1075 }
1076 attr {
1077 key: "_class"
1078 value {
1079 list {
1080 s: "loc:@EpisodicReplayBuffer/add/end_episode"
1081 }
1082 }
1083 }
1084 }
1085 node {
1086 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
1087 op: "Identity"
1088 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Switch"
1089 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1090 attr {
1091 key: "T"
1092 value {
1093 type: DT_BOOL
1094 }
1095 }
1096 }
1097 node {
1098 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
1099 op: "Const"
1100 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/NoOp"
1101 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1102 attr {
1103 key: "dtype"
1104 value {
1105 type: DT_BOOL
1106 }
1107 }
1108 attr {
1109 key: "value"
1110 value {
1111 tensor {
1112 dtype: DT_BOOL
1113 tensor_shape {
1114 }
1115 tensor_content: "\001"
1116 }
1117 }
1118 }
1119 }
1120 node {
1121 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
1122 op: "Assert"
1123 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch"
1124 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_1"
1125 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_2"
1126 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1127 attr {
1128 key: "T"
1129 value {
1130 list {
1131 type: DT_BOOL
1132 type: DT_BOOL
1133 }
1134 }
1135 }
1136 attr {
1137 key: "summarize"
1138 value {
1139 i: 3
1140 }
1141 }
1142 }
1143 node {
1144 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
1145 op: "Identity"
1146 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
1147 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
1148 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1149 attr {
1150 key: "T"
1151 value {
1152 type: DT_BOOL
1153 }
1154 }
1155 attr {
1156 key: "_class"
1157 value {
1158 list {
1159 s: "loc:@EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
1160 }
1161 }
1162 }
1163 }
1164 node {
1165 name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
1166 op: "Merge"
1167 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
1168 input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
1169 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1170 attr {
1171 key: "N"
1172 value {
1173 i: 2
1174 }
1175 }
1176 attr {
1177 key: "T"
1178 value {
1179 type: DT_BOOL
1180 }
1181 }
1182 }
1183 node {
1184 name: "EpisodicReplayBuffer/add/FloorMod/y"
1185 op: "Const"
1186 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
1187 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1188 attr {
1189 key: "dtype"
1190 value {
1191 type: DT_INT64
1192 }
1193 }
1194 attr {
1195 key: "value"
1196 value {
1197 tensor {
1198 dtype: DT_INT64
1199 tensor_shape {
1200 }
1201 int64_val: 5000
1202 }
1203 }
1204 }
1205 }
1206 node {
1207 name: "EpisodicReplayBuffer/add/ReadVariableOp"
1208 op: "ReadVariableOp"
1209 input: "episodicreplaybuffer_add_readvariableop_resource"
1210 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
1211 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1212 attr {
1213 key: "dtype"
1214 value {
1215 type: DT_INT64
1216 }
1217 }
1218 }
1219 node {
1220 name: "EpisodicReplayBuffer/add/Less/y"
1221 op: "Const"
1222 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
1223 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1224 attr {
1225 key: "dtype"
1226 value {
1227 type: DT_INT64
1228 }
1229 }
1230 attr {
1231 key: "value"
1232 value {
1233 tensor {
1234 dtype: DT_INT64
1235 tensor_shape {
1236 }
1237 int64_val: 0
1238 }
1239 }
1240 }
1241 }
1242 node {
1243 name: "EpisodicReplayBuffer/add/Less"
1244 op: "Less"
1245 input: "EpisodicReplayBuffer/add/ReadVariableOp"
1246 input: "EpisodicReplayBuffer/add/Less/y"
1247 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1248 attr {
1249 key: "T"
1250 value {
1251 type: DT_INT64
1252 }
1253 }
1254 }
1255 node {
1256 name: "EpisodicReplayBuffer/add/or"
1257 op: "LogicalOr"
1258 input: "EpisodicReplayBuffer/add/begin_episode"
1259 input: "EpisodicReplayBuffer/add/Less"
1260 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1261 }
1262 node {
1263 name: "EpisodicReplayBuffer/add/get_episode_id/pred_id"
1264 op: "Identity"
1265 input: "EpisodicReplayBuffer/add/or"
1266 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1267 attr {
1268 key: "T"
1269 value {
1270 type: DT_BOOL
1271 }
1272 }
1273 }
1274 node {
1275 name: "EpisodicReplayBuffer/add/get_episode_id/Switch"
1276 op: "Switch"
1277 input: "EpisodicReplayBuffer/add/or"
1278 input: "EpisodicReplayBuffer/add/or"
1279 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1280 attr {
1281 key: "T"
1282 value {
1283 type: DT_BOOL
1284 }
1285 }
1286 }
1287 node {
1288 name: "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/AssignVariableOp/Switch"
1289 op: "Switch"
1290 input: "episodicreplaybuffer_add_readvariableop_resource"
1291 input: "EpisodicReplayBuffer/add/get_episode_id/pred_id"
1292 input: "^EpisodicReplayBuffer/add/ReadVariableOp"
1293 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1294 attr {
1295 key: "T"
1296 value {
1297 type: DT_RESOURCE
1298 }
1299 }
1300 attr {
1301 key: "_class"
1302 value {
1303 list {
1304 s: "loc:@EpisodicReplayBuffer/add/ReadVariableOp/resource"
1305 }
1306 }
1307 }
1308 }
1309 node {
1310 name: "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/ReadVariableOp_3"
1311 op: "ReadVariableOp"
1312 input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
1313 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1314 attr {
1315 key: "dtype"
1316 value {
1317 type: DT_INT64
1318 }
1319 }
1320 }
1321 library {
1322 }
1323 versions {
1324 producer: 27
1325 }
1326 )EOF";
1327
1328 GrapplerItem item;
1329 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
1330 item.fetch = {
1331 "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/"
1332 "ReadVariableOp_3"};
1333
1334 LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
1335 GraphDef output;
1336 Status status = optimizer.Optimize(nullptr, item, &output);
1337 TF_CHECK_OK(status);
1338
1339 bool found_merge = false;
1340 for (const auto& node : output.node()) {
1341 if (node.name() ==
1342 "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge") {
1343 found_merge = true;
1344 }
1345 }
1346
1347 EXPECT_TRUE(found_merge)
1348 << "Merge node was deleted, but it shouldn't have been.";
1349 }
1350
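// A while loop whose condition (20 < 10) is already false on entry, so the
// body never runs. Dead-branch removal should prune the body nodes
// (while/Identity, while/add, while/add/y, while/NextIteration), leaving 8
// nodes, and the fetched while/Exit value must be unchanged.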
1351 TEST_F(LoopOptimizerTest, RemoveDeadBranchesZeroIterWhile) {
1352 const string gdef_ascii = R"EOF(
1353 node {
1354 name: "Const"
1355 op: "Const"
1356 attr {
1357 key: "dtype"
1358 value {
1359 type: DT_INT32
1360 }
1361 }
1362 attr {
1363 key: "value"
1364 value {
1365 tensor {
1366 dtype: DT_INT32
1367 tensor_shape {
1368 }
1369 int_val: 20
1370 }
1371 }
1372 }
1373 }
1374 node {
1375 name: "while/Enter"
1376 op: "Enter"
1377 input: "Const"
1378 attr {
1379 key: "T"
1380 value {
1381 type: DT_INT32
1382 }
1383 }
1384 attr {
1385 key: "frame_name"
1386 value {
1387 s: "while/while/"
1388 }
1389 }
1390 attr {
1391 key: "is_constant"
1392 value {
1393 b: false
1394 }
1395 }
1396 attr {
1397 key: "parallel_iterations"
1398 value {
1399 i: 1
1400 }
1401 }
1402 }
1403 node {
1404 name: "while/Merge"
1405 op: "Merge"
1406 input: "while/Enter"
1407 input: "while/NextIteration"
1408 attr {
1409 key: "N"
1410 value {
1411 i: 2
1412 }
1413 }
1414 attr {
1415 key: "T"
1416 value {
1417 type: DT_INT32
1418 }
1419 }
1420 }
1421 node {
1422 name: "while/Less/y"
1423 op: "Const"
1424 input: "^while/Merge"
1425 attr {
1426 key: "dtype"
1427 value {
1428 type: DT_INT32
1429 }
1430 }
1431 attr {
1432 key: "value"
1433 value {
1434 tensor {
1435 dtype: DT_INT32
1436 tensor_shape {
1437 }
1438 int_val: 10
1439 }
1440 }
1441 }
1442 }
1443 node {
1444 name: "while/Less"
1445 op: "Less"
1446 input: "while/Merge"
1447 input: "while/Less/y"
1448 attr {
1449 key: "T"
1450 value {
1451 type: DT_INT32
1452 }
1453 }
1454 }
1455 node {
1456 name: "while/LoopCond"
1457 op: "LoopCond"
1458 input: "while/Less"
1459 }
1460 node {
1461 name: "while/Switch"
1462 op: "Switch"
1463 input: "while/Merge"
1464 input: "while/LoopCond"
1465 attr {
1466 key: "T"
1467 value {
1468 type: DT_INT32
1469 }
1470 }
1471 attr {
1472 key: "_class"
1473 value {
1474 list {
1475 s: "loc:@while/Merge"
1476 }
1477 }
1478 }
1479 }
1480 node {
1481 name: "while/Identity"
1482 op: "Identity"
1483 input: "while/Switch:1"
1484 attr {
1485 key: "T"
1486 value {
1487 type: DT_INT32
1488 }
1489 }
1490 }
1491 node {
1492 name: "while/add/y"
1493 op: "Const"
1494 input: "^while/Identity"
1495 attr {
1496 key: "dtype"
1497 value {
1498 type: DT_INT32
1499 }
1500 }
1501 attr {
1502 key: "value"
1503 value {
1504 tensor {
1505 dtype: DT_INT32
1506 tensor_shape {
1507 }
1508 int_val: 1
1509 }
1510 }
1511 }
1512 }
1513 node {
1514 name: "while/add"
1515 op: "Add"
1516 input: "while/Identity"
1517 input: "while/add/y"
1518 attr {
1519 key: "T"
1520 value {
1521 type: DT_INT32
1522 }
1523 }
1524 }
1525 node {
1526 name: "while/NextIteration"
1527 op: "NextIteration"
1528 input: "while/add"
1529 attr {
1530 key: "T"
1531 value {
1532 type: DT_INT32
1533 }
1534 }
1535 }
1536 node {
1537 name: "while/Exit"
1538 op: "Exit"
1539 input: "while/Switch"
1540 attr {
1541 key: "T"
1542 value {
1543 type: DT_INT32
1544 }
1545 }
1546 }
1547 versions {
1548 producer: 21
1549 }
1550 )EOF";
1551
1552 GrapplerItem item;
1553 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
1554 item.fetch = {"while/Exit"};
1555 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1556 ASSERT_EQ(tensors_expected.size(), 1);
1557
1558 LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
1559 GraphDef output;
1560 Status status = optimizer.Optimize(nullptr, item, &output);
1561 TF_CHECK_OK(status);
1562 auto tensors_got = EvaluateNodes(output, item.fetch);
1563 ASSERT_EQ(tensors_got.size(), 1);
1564 test::ExpectTensorEqual<int32>(tensors_got[0], tensors_expected[0]);
1565
1566 int nodes_present = 0;
1567 for (const NodeDef& node : output.node()) {
1568 // The loop never runs, so every node on the Switch's true branch (the loop body) should be pruned.
1569 if (node.name() == "while/add") {
1570 LOG(ERROR) << "while/add is present after optimization";
1571 } else if (node.name() == "while/add/y") {
1572 LOG(ERROR) << "while/add/y is present after optimization";
1573 } else if (node.name() == "while/NextIteration") {
1574 LOG(ERROR) << "while/NextIteration is present after optimization";
1575 } else if (node.name() == "while/Identity") {
1576 LOG(ERROR) << "while/Identity is present after optimization";
1577 }
1578 ++nodes_present;
1579 }
1580 EXPECT_EQ(nodes_present, 8);
1581 }
1582
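// The Switch predicate Const_1 is overridden by a feed, so it cannot be
// treated as a compile-time constant: no branch may be pruned and cond/Merge
// must keep both of its inputs.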
1583 TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantFeed) {
1584 const string gdef_ascii = R"EOF(
1585 node {
1586 name: "Const"
1587 op: "Const"
1588 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1589 attr {
1590 key: "dtype"
1591 value {
1592 type: DT_STRING
1593 }
1594 }
1595 attr {
1596 key: "value"
1597 value {
1598 tensor {
1599 dtype: DT_STRING
1600 tensor_shape {
1601 dim {
1602 size: 1
1603 }
1604 }
1605 string_val: "I\'m a value!"
1606 }
1607 }
1608 }
1609 }
1610 node {
1611 name: "cond/Switch_1"
1612 op: "Switch"
1613 input: "Const"
1614 input: "Const_1"
1615 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1616 attr {
1617 key: "T"
1618 value {
1619 type: DT_STRING
1620 }
1621 }
1622 attr {
1623 key: "_class"
1624 value {
1625 list {
1626 s: "loc:@Const"
1627 }
1628 }
1629 }
1630 }
1631 node {
1632 name: "Const_1"
1633 op: "Const"
1634 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1635 attr {
1636 key: "dtype"
1637 value {
1638 type: DT_BOOL
1639 }
1640 }
1641 attr {
1642 key: "value"
1643 value {
1644 tensor {
1645 dtype: DT_BOOL
1646 tensor_shape {
1647 }
1648 bool_val: true
1649 }
1650 }
1651 }
1652 }
1653 node {
1654 name: "cond/Switch"
1655 op: "Switch"
1656 input: "Const_1"
1657 input: "Const_1"
1658 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1659 attr {
1660 key: "T"
1661 value {
1662 type: DT_BOOL
1663 }
1664 }
1665 }
1666 node {
1667 name: "cond/switch_t"
1668 op: "Identity"
1669 input: "cond/Switch:1"
1670 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1671 attr {
1672 key: "T"
1673 value {
1674 type: DT_BOOL
1675 }
1676 }
1677 }
1678 node {
1679 name: "cond/Const"
1680 op: "Const"
1681 input: "^cond/switch_t"
1682 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1683 attr {
1684 key: "dtype"
1685 value {
1686 type: DT_STRING
1687 }
1688 }
1689 attr {
1690 key: "value"
1691 value {
1692 tensor {
1693 dtype: DT_STRING
1694 tensor_shape {
1695 dim {
1696 size: 1
1697 }
1698 }
1699 string_val: ""
1700 }
1701 }
1702 }
1703 }
1704 node {
1705 name: "cond/Merge"
1706 op: "Merge"
1707 input: "cond/Switch_1"
1708 input: "cond/Const"
1709 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1710 attr {
1711 key: "N"
1712 value {
1713 i: 2
1714 }
1715 }
1716 attr {
1717 key: "T"
1718 value {
1719 type: DT_STRING
1720 }
1721 }
1722 }
1723 node {
1724 name: "Identity"
1725 op: "Identity"
1726 input: "cond/Merge"
1727 device: "/job:localhost/replica:0/task:0/device:CPU:0"
1728 attr {
1729 key: "T"
1730 value {
1731 type: DT_STRING
1732 }
1733 }
1734 }
1735 library {
1736 }
1737 versions {
1738 producer: 27
1739 }
1740 )EOF";
1741
1742 GrapplerItem item;
1743 CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
1744 item.fetch = {"Identity"};
1745 Tensor feed_tensor(DT_BOOL, {});
1746 feed_tensor.flat<bool>()(0) = false;
1747 item.feed.push_back({"Const_1", feed_tensor});
1748 auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1749 ASSERT_EQ(tensors_expected.size(), 1);
1750
1751 LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
1752 GraphDef output;
1753 Status status = optimizer.Optimize(nullptr, item, &output);
1754 TF_CHECK_OK(status);
1755 auto tensors_got = EvaluateNodes(output, item.fetch);
1756 ASSERT_EQ(tensors_got.size(), 1);
1757 test::ExpectTensorEqual<tstring>(tensors_got[0], tensors_expected[0]);
1758
1759 EXPECT_EQ(output.node_size(), 8);
1760
1761 // No rewrite: the Switch predicate Const_1 is overridden by a feed, so it is not a true constant.
1762 bool found = false;
1763 for (const NodeDef& node : output.node()) {
1764 if (node.name() == "cond/Merge") {
1765 EXPECT_EQ(node.op(), "Merge");
1766 ASSERT_EQ(node.input_size(), 2);
1767 EXPECT_EQ(node.input(0), "cond/Switch_1");
1768 EXPECT_EQ(node.input(1), "cond/Const");
1769 found = true;
1770 break;
1771 }
1772 }
1773 EXPECT_TRUE(found);
1774 }
1775
1776 } // namespace grappler
1777 } // namespace tensorflow
1778