xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
17 
18 #include "tensorflow/cc/ops/standard_ops.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/grappler/grappler_item.h"
22 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
23 #include "tensorflow/core/grappler/utils.h"
24 #include "tensorflow/core/grappler/utils/graph_view.h"
25 #include "tensorflow/core/grappler/utils/grappler_test.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 
32 class LoopOptimizerTest : public GrapplerTest {
33  protected:
34   // These helpers always sets T=DT_FLOAT.
AddEnterNode(const string & name,const string & frame,const bool is_constant,const int piterations,const std::vector<string> & inputs,GraphDef * graph) const35   void AddEnterNode(const string& name, const string& frame,
36                     const bool is_constant, const int piterations,
37                     const std::vector<string>& inputs, GraphDef* graph) const {
38     std::vector<std::pair<string, AttrValue>> attributes;
39     AttrValue type;
40     type.set_type(DT_FLOAT);
41     attributes.emplace_back("T", type);
42     AttrValue frame_name;
43     frame_name.set_s(frame);
44     attributes.emplace_back("frame_name", frame_name);
45     AttrValue is_const;
46     is_const.set_b(is_constant);
47     attributes.emplace_back("is_constant", is_const);
48     AttrValue parallel_iterations;
49     parallel_iterations.set_i(piterations);
50     attributes.emplace_back("parallel_iterations", parallel_iterations);
51     AddNode(name, "Enter", inputs, attributes, graph);
52   }
53 
AddSimpleNode(const string & name,const string & op,const std::vector<string> & inputs,GraphDef * graph) const54   void AddSimpleNode(const string& name, const string& op,
55                      const std::vector<string>& inputs, GraphDef* graph) const {
56     std::vector<std::pair<string, AttrValue>> attributes;
57     AttrValue type;
58     type.set_type(DT_FLOAT);
59     attributes.emplace_back("T", type);
60     AddNode(name, op, inputs, attributes, graph);
61   }
62 
EnableOnlyLoopInvariantNodeMotion(LoopOptimizer * optimizer)63   void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
64     DisableAllStages(optimizer);
65     optimizer->options_.enable_loop_invariant_node_motion = true;
66   }
67 
EnableOnlyStackPushRemoval(LoopOptimizer * optimizer)68   void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
69     DisableAllStages(optimizer);
70     optimizer->options_.enable_stack_push_removal = true;
71   }
72 
73  private:
DisableAllStages(LoopOptimizer * optimizer)74   void DisableAllStages(LoopOptimizer* optimizer) {
75     LoopOptimizer::LoopOptimizerOptions options;
76     options.enable_loop_invariant_node_motion = false;
77     options.enable_stack_push_removal = false;
78     optimizer->options_ = options;
79   }
80 };
81 
TEST_F(LoopOptimizerTest,Basic)82 TEST_F(LoopOptimizerTest, Basic) {
83   GraphDef graph;
84   AddSimpleNode("In", "Identity", {}, &graph);
85   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
86                &graph);
87   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
88                 &graph);
89   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
90   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
91   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
92   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
93   AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
94   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
95   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
96   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
97   AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
98   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
99   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
100 
101   GrapplerItem item;
102   item.graph = graph;
103 
104   LoopOptimizer optimizer;
105   EnableOnlyLoopInvariantNodeMotion(&optimizer);
106   GraphDef output;
107   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
108 
109   {  // Original graph.
110     Status status;
111     utils::GraphView view(&graph, &status);
112     TF_ASSERT_OK(status);
113     FrameView frames;
114     TF_EXPECT_OK(frames.InferFromGraphView(view));
115 
116     EXPECT_EQ(frames.num_frames(), 1);
117     const auto* invariant_add_node = view.GetNode("InvariantAdd");
118     ASSERT_NE(invariant_add_node, nullptr);
119     const auto* invariant_add_node_def = invariant_add_node->node();
120     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
121     EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
122     const auto* variant_add_node = view.GetNode("VariantAdd");
123     ASSERT_NE(variant_add_node, nullptr);
124     const auto* variant_add_node_def = variant_add_node->node();
125     ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
126     EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
127   }
128 
129   {  // Optimized graph.
130     Status status;
131     utils::GraphView view(&output, &status);
132     TF_ASSERT_OK(status);
133     FrameView frames;
134     TF_EXPECT_OK(frames.InferFromGraphView(view));
135 
136     EXPECT_EQ(frames.num_frames(), 1);
137     const auto* invariant_add_node = view.GetNode("InvariantAdd");
138     ASSERT_NE(invariant_add_node, nullptr);
139     const auto* invariant_add_node_def = invariant_add_node->node();
140     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
141     const auto* variant_add_node = view.GetNode("VariantAdd");
142     ASSERT_NE(variant_add_node, nullptr);
143     const auto* variant_add_node_def = variant_add_node->node();
144     ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
145     EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
146   }
147 }
148 
TEST_F(LoopOptimizerTest,Const)149 TEST_F(LoopOptimizerTest, Const) {
150   GraphDef graph;
151   AddSimpleNode("In", "Identity", {}, &graph);
152   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
153                &graph);
154   AddSimpleNode("Const", "Const", {"^Identity"}, &graph);
155   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "Const"}, &graph);
156   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
157   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
158   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
159   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
160   AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y"}, &graph);
161   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
162   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
163   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
164   AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
165   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
166   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
167 
168   GrapplerItem item;
169   item.graph = graph;
170 
171   LoopOptimizer optimizer;
172   EnableOnlyLoopInvariantNodeMotion(&optimizer);
173   GraphDef output;
174   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
175 
176   {  // Original graph.
177     Status status;
178     utils::GraphView view(&graph, &status);
179     TF_ASSERT_OK(status);
180     FrameView frames;
181     TF_EXPECT_OK(frames.InferFromGraphView(view));
182 
183     EXPECT_EQ(frames.num_frames(), 1);
184     const auto* invariant_add_node = view.GetNode("InvariantAdd");
185     ASSERT_NE(invariant_add_node, nullptr);
186     const auto* invariant_add_node_def = invariant_add_node->node();
187     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
188     EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
189     const auto* const_node = view.GetNode("Const");
190     ASSERT_NE(const_node, nullptr);
191     const auto* const_node_node_def = const_node->node();
192     ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 1);
193     EXPECT_EQ(frames.Frames(*const_node_node_def).back(), 0);
194   }
195 
196   {  // Optimized graph.
197     Status status;
198     utils::GraphView view(&output, &status);
199     TF_ASSERT_OK(status);
200     FrameView frames;
201     TF_EXPECT_OK(frames.InferFromGraphView(view));
202 
203     EXPECT_EQ(frames.num_frames(), 1);
204     const auto* invariant_add_node = view.GetNode("InvariantAdd");
205     ASSERT_NE(invariant_add_node, nullptr);
206     const auto* invariant_add_node_def = invariant_add_node->node();
207     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
208     const auto* const_node = view.GetNode("Const");
209     ASSERT_NE(const_node, nullptr);
210     const auto* const_node_node_def = const_node->node();
211     ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 0);
212   }
213 }
214 
TEST_F(LoopOptimizerTest,ControlOutput)215 TEST_F(LoopOptimizerTest, ControlOutput) {
216   GraphDef graph;
217   AddSimpleNode("In", "Identity", {}, &graph);
218   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
219                &graph);
220   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
221                 &graph);
222   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
223   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
224   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
225   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
226   AddSimpleNode("Less", "Less", {"VariantAdd", "Less/y", "^InvariantAdd"},
227                 &graph);
228   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
229   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
230   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
231   AddSimpleNode("NextIteration", "NextIteration", {"VariantAdd"}, &graph);
232   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
233   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
234 
235   GrapplerItem item;
236   item.graph = graph;
237 
238   LoopOptimizer optimizer;
239   EnableOnlyLoopInvariantNodeMotion(&optimizer);
240   GraphDef output;
241   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
242 
243   {  // Original graph.
244     Status status;
245     utils::GraphView view(&graph, &status);
246     TF_ASSERT_OK(status);
247     FrameView frames;
248     TF_EXPECT_OK(frames.InferFromGraphView(view));
249 
250     EXPECT_EQ(frames.num_frames(), 1);
251     const auto* invariant_add_node = view.GetNode("InvariantAdd");
252     ASSERT_NE(invariant_add_node, nullptr);
253     const auto* invariant_add_node_def = invariant_add_node->node();
254     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
255     EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
256   }
257 
258   {  // Optimized graph.
259     Status status;
260     utils::GraphView view(&output, &status);
261     TF_ASSERT_OK(status);
262     FrameView frames;
263     TF_EXPECT_OK(frames.InferFromGraphView(view));
264 
265     EXPECT_EQ(frames.num_frames(), 1);
266     const auto* invariant_add_node = view.GetNode("InvariantAdd");
267     ASSERT_NE(invariant_add_node, nullptr);
268     const auto* invariant_add_node_def = invariant_add_node->node();
269     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
270     EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
271   }
272 }
273 
TEST_F(LoopOptimizerTest,NestedLoop1)274 TEST_F(LoopOptimizerTest, NestedLoop1) {
275   GraphDef graph;
276   AddSimpleNode("In", "Identity", {}, &graph);
277   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
278                &graph);
279   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
280                 &graph);
281   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
282   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
283   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
284   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
285   AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
286   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
287   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
288   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
289   AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
290   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
291   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
292 
293   AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
294                {"VariantAdd"}, &graph);
295   AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
296                 &graph);
297   AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
298   AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
299                {"VariantEnter"}, &graph);
300   AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
301   AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
302   AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
303   AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
304   AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
305   AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
306   AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
307   AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
308 
309   GrapplerItem item;
310   item.graph = graph;
311 
312   LoopOptimizer optimizer;
313   EnableOnlyLoopInvariantNodeMotion(&optimizer);
314   GraphDef output;
315   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
316 
317   {  // Original graph.
318     Status status;
319     utils::GraphView view(&graph, &status);
320     TF_ASSERT_OK(status);
321     FrameView frames;
322     TF_EXPECT_OK(frames.InferFromGraphView(view));
323 
324     EXPECT_EQ(frames.num_frames(), 2);
325     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
326     ASSERT_NE(invariant_add_2_node, nullptr);
327     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
328     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
329     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
330     const auto* variant_add_2_node = view.GetNode("VariantAdd2");
331     ASSERT_NE(variant_add_2_node, nullptr);
332     const auto* variant_add_2_node_def = variant_add_2_node->node();
333     ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
334     EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
335     const auto* invariant_add_node = view.GetNode("InvariantAdd");
336     ASSERT_NE(invariant_add_node, nullptr);
337     const auto* invariant_add_node_def = invariant_add_node->node();
338     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
339     EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
340   }
341 
342   {  // Optimized graph.
343     Status status;
344     utils::GraphView view(&output, &status);
345     TF_ASSERT_OK(status);
346     FrameView frames;
347     TF_EXPECT_OK(frames.InferFromGraphView(view));
348 
349     EXPECT_EQ(frames.num_frames(), 2);
350     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
351     ASSERT_NE(invariant_add_2_node, nullptr);
352     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
353     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
354     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
355     const auto* variant_add_2_node = view.GetNode("VariantAdd2");
356     ASSERT_NE(variant_add_2_node, nullptr);
357     const auto* variant_add_2_node_def = variant_add_2_node->node();
358     ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
359     EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
360     const auto* invariant_add_node = view.GetNode("InvariantAdd");
361     ASSERT_NE(invariant_add_node, nullptr);
362     const auto* invariant_add_node_def = invariant_add_node->node();
363     ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
364   }
365 }
366 
TEST_F(LoopOptimizerTest,NestedLoop2)367 TEST_F(LoopOptimizerTest, NestedLoop2) {
368   GraphDef graph;
369   AddSimpleNode("In", "Identity", {}, &graph);
370   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
371                &graph);
372   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
373                 &graph);
374   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
375   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
376   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
377   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
378   AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
379   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
380   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
381   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
382   AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
383   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
384   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
385 
386   AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
387                {"InvariantAdd"}, &graph);
388   AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "InvariantEnter2"},
389                 &graph);
390   AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
391   AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
392                {"VariantEnter"}, &graph);
393   AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
394   AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
395   AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
396   AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
397   AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
398   AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
399   AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
400   AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
401 
402   GrapplerItem item;
403   item.graph = graph;
404 
405   LoopOptimizer optimizer;
406   EnableOnlyLoopInvariantNodeMotion(&optimizer);
407   GraphDef output;
408   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
409 
410   {  // Original graph.
411     Status status;
412     utils::GraphView view(&graph, &status);
413     TF_ASSERT_OK(status);
414     FrameView frames;
415     TF_EXPECT_OK(frames.InferFromGraphView(view));
416 
417     EXPECT_EQ(frames.num_frames(), 2);
418     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
419     ASSERT_NE(invariant_add_2_node, nullptr);
420     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
421     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
422     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
423     const auto* variant_add_2_node = view.GetNode("VariantAdd2");
424     ASSERT_NE(variant_add_2_node, nullptr);
425     const auto* variant_add_2_node_def = variant_add_2_node->node();
426     ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
427     EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
428   }
429 
430   {  // Optimized graph.
431     Status status;
432     utils::GraphView view(&output, &status);
433     TF_ASSERT_OK(status);
434     FrameView frames;
435     TF_EXPECT_OK(frames.InferFromGraphView(view));
436 
437     EXPECT_EQ(frames.num_frames(), 2);
438     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
439     ASSERT_NE(invariant_add_2_node, nullptr);
440     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
441     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
442     const auto* variant_add_2_node = view.GetNode("VariantAdd2");
443     ASSERT_NE(variant_add_2_node, nullptr);
444     const auto* variant_add_2_node_def = variant_add_2_node->node();
445     ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
446     EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
447   }
448 }
449 
TEST_F(LoopOptimizerTest,NestedLoopConst1)450 TEST_F(LoopOptimizerTest, NestedLoopConst1) {
451   GraphDef graph;
452   AddSimpleNode("In", "Identity", {}, &graph);
453   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
454                &graph);
455   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
456                 &graph);
457   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
458   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
459   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
460   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
461   AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
462   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
463   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
464   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
465   AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
466   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
467   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
468 
469   AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
470                {"VariantAdd"}, &graph);
471   AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
472   AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
473   AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
474   AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
475                {"VariantEnter"}, &graph);
476   AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
477   AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
478   AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
479   AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
480   AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
481   AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
482   AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
483   AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
484 
485   GrapplerItem item;
486   item.graph = graph;
487 
488   LoopOptimizer optimizer;
489   EnableOnlyLoopInvariantNodeMotion(&optimizer);
490   GraphDef output;
491   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
492 
493   {  // Original graph.
494     Status status;
495     utils::GraphView view(&graph, &status);
496     TF_ASSERT_OK(status);
497     FrameView frames;
498     TF_EXPECT_OK(frames.InferFromGraphView(view));
499 
500     EXPECT_EQ(frames.num_frames(), 2);
501     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
502     ASSERT_NE(invariant_add_2_node, nullptr);
503     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
504     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
505     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
506     const auto* const_2_node = view.GetNode("Const2");
507     ASSERT_NE(const_2_node, nullptr);
508     const auto* const_2_node_def = const_2_node->node();
509     ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
510     EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
511   }
512 
513   {  // Optimized graph.
514     Status status;
515     utils::GraphView view(&output, &status);
516     TF_ASSERT_OK(status);
517     FrameView frames;
518     TF_EXPECT_OK(frames.InferFromGraphView(view));
519 
520     EXPECT_EQ(frames.num_frames(), 2);
521     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
522     ASSERT_NE(invariant_add_2_node, nullptr);
523     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
524     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
525     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
526     const auto* const_2_node = view.GetNode("Const2");
527     ASSERT_NE(const_2_node, nullptr);
528     const auto* const_2_node_def = const_2_node->node();
529     ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 1);
530     EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 0);
531   }
532 }
533 
TEST_F(LoopOptimizerTest,NestedLoopConst2)534 TEST_F(LoopOptimizerTest, NestedLoopConst2) {
535   GraphDef graph;
536   AddSimpleNode("In", "Identity", {}, &graph);
537   AddEnterNode("InvariantEnter", "while/while_context", true, 1, {"In"},
538                &graph);
539   AddSimpleNode("InvariantAdd", "Add", {"InvariantEnter", "InvariantEnter"},
540                 &graph);
541   AddSimpleNode("VariantAdd", "Add", {"InvariantAdd", "Identity"}, &graph);
542   AddEnterNode("VariantEnter", "while/while_context", false, 1, {"In"}, &graph);
543   AddSimpleNode("Merge", "Merge", {"VariantEnter", "NextIteration"}, &graph);
544   AddSimpleNode("Less/y", "Const", {"^Identity"}, &graph);
545   AddSimpleNode("Less", "Less", {"Exit2", "Less/y"}, &graph);
546   AddSimpleNode("LoopCond", "LoopCond", {"Less"}, &graph);
547   AddSimpleNode("Switch", "Switch", {"Merge", "LoopCond"}, &graph);
548   AddSimpleNode("Identity", "Identity", {"Switch:1"}, &graph);
549   AddSimpleNode("NextIteration", "NextIteration", {"Exit2"}, &graph);
550   AddSimpleNode("Exit", "Exit", {"Switch"}, &graph);
551   AddSimpleNode("Out", "Identity", {"Exit"}, &graph);
552 
553   AddEnterNode("InvariantEnter2", "while/while/while_context", true, 1,
554                {"InvariantAdd"}, &graph);
555   AddSimpleNode("Const2", "Const", {"^Identity2"}, &graph);
556   AddSimpleNode("InvariantAdd2", "Add", {"InvariantEnter2", "Const2"}, &graph);
557   AddSimpleNode("VariantAdd2", "Add", {"InvariantAdd2", "Identity2"}, &graph);
558   AddEnterNode("VariantEnter2", "while/while/while_context", false, 1,
559                {"VariantEnter"}, &graph);
560   AddSimpleNode("Merge2", "Merge", {"VariantEnter2", "NextIteration2"}, &graph);
561   AddSimpleNode("Less2/y", "Const", {"^Identity2"}, &graph);
562   AddSimpleNode("Less2", "Less", {"VariantAdd2", "Less2/y"}, &graph);
563   AddSimpleNode("LoopCond2", "LoopCond", {"Less2"}, &graph);
564   AddSimpleNode("Switch2", "Switch", {"Merge2", "LoopCond2"}, &graph);
565   AddSimpleNode("Identity2", "Identity", {"Switch2:1"}, &graph);
566   AddSimpleNode("NextIteration2", "NextIteration", {"VariantAdd2"}, &graph);
567   AddSimpleNode("Exit2", "Exit", {"Switch2"}, &graph);
568 
569   GrapplerItem item;
570   item.graph = graph;
571 
572   LoopOptimizer optimizer;
573   EnableOnlyLoopInvariantNodeMotion(&optimizer);
574   GraphDef output;
575   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
576 
577   {  // Original graph.
578     Status status;
579     utils::GraphView view(&graph, &status);
580     TF_ASSERT_OK(status);
581     FrameView frames;
582     TF_EXPECT_OK(frames.InferFromGraphView(view));
583 
584     EXPECT_EQ(frames.num_frames(), 2);
585     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
586     ASSERT_NE(invariant_add_2_node, nullptr);
587     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
588     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
589     EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
590     const auto* const_2_node = view.GetNode("Const2");
591     ASSERT_NE(const_2_node, nullptr);
592     const auto* const_2_node_def = const_2_node->node();
593     ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
594     EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
595   }
596 
597   {  // Optimized graph.
598     Status status;
599     utils::GraphView view(&output, &status);
600     TF_ASSERT_OK(status);
601     FrameView frames;
602     TF_EXPECT_OK(frames.InferFromGraphView(view));
603 
604     EXPECT_EQ(frames.num_frames(), 2);
605     const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
606     ASSERT_NE(invariant_add_2_node, nullptr);
607     const auto* invariant_add_2_node_def = invariant_add_2_node->node();
608     ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
609     const auto* const_2_node = view.GetNode("Const2");
610     ASSERT_NE(const_2_node, nullptr);
611     const auto* const_2_node_def = const_2_node->node();
612     ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 0);
613   }
614 }
615 
VerifyGraphsEqual(const GraphDef & original_graph,const GraphDef & optimized_graph,const string & func)616 void VerifyGraphsEqual(const GraphDef& original_graph,
617                        const GraphDef& optimized_graph, const string& func) {
618   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
619   for (int i = 0; i < original_graph.node_size(); ++i) {
620     const NodeDef& original = original_graph.node(i);
621     const NodeDef& optimized = optimized_graph.node(i);
622     EXPECT_EQ(optimized.name(), original.name()) << func;
623     EXPECT_EQ(optimized.op(), original.op()) << func;
624     ASSERT_EQ(optimized.input_size(), original.input_size()) << func;
625     for (int j = 0; j < original.input_size(); ++j) {
626       EXPECT_EQ(optimized.input(j), original.input(j)) << func;
627     }
628   }
629 }
630 
TEST_F(LoopOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize: stack push
  // removal is expected to return it unchanged.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  // ASSERT_TRUE instead of CHECK: failing to produce an item should fail this
  // test only, not abort the entire test binary.
  ASSERT_TRUE(fake_input.NextItem(&item));

  LoopOptimizer optimizer;
  EnableOnlyStackPushRemoval(&optimizer);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
645 
TEST_F(LoopOptimizerTest, RemovePushNoOp) {
  // Three stacks whose pushes the test expects to be kept, so the optimized
  // graph must match the input node for node.
  constexpr char kFrame[] = "frame_name";

  GrapplerItem item;
  GraphDef& g = item.graph;
  AddSimpleNode("c", "Const", {}, &g);

  // Case 1: a push with a matching pop on the same stack.
  AddSimpleNode("stack1", "StackV2", {}, &g);
  AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &g);
  AddSimpleNode("pop1", "StackPopV2", {"stack1"}, &g);
  AddSimpleNode("id1", "Identity", {"pop1"}, &g);

  // Case 2: the same push/pop pair, reached through Enter nodes.
  AddSimpleNode("stack2", "StackV2", {}, &g);
  AddEnterNode("enter2_c", kFrame, /*is_constant=*/false, /*piterations=*/1,
               {"c"}, &g);
  AddEnterNode("enter2_stack2", kFrame, /*is_constant=*/false,
               /*piterations=*/1, {"stack2"}, &g);
  AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, &g);
  AddSimpleNode("pop2", "StackPopV2", {"enter2_stack2"}, &g);
  AddSimpleNode("id2", "Identity", {"pop2"}, &g);

  // Case 3: the stack also fans out to an op that is neither push nor pop.
  AddSimpleNode("stack3", "StackV2", {}, &g);
  AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &g);
  AddSimpleNode("stop", "StopGradient", {"stack3"}, &g);

  LoopOptimizer optimizer;
  EnableOnlyStackPushRemoval(&optimizer);
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
674 
TEST_F(LoopOptimizerTest, RemovePushNoPopButStackLives) {
  // Both stacks only have pushes (no pops), but both stacks are listed in
  // keep_ops, so the optimizer must leave the graph as it is.
  GrapplerItem item;
  GraphDef* g = &item.graph;
  AddSimpleNode("c", "Const", {}, g);

  // Push wired directly to the stack.
  AddSimpleNode("stack1", "StackV2", {}, g);
  AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, g);

  // Push reached through Enter nodes.
  AddSimpleNode("stack2", "StackV2", {}, g);
  AddEnterNode("enter2_c", "frame_name", false, 1, {"c"}, g);
  AddEnterNode("enter2_stack2", "frame_name", false, 1, {"stack2"}, g);
  AddSimpleNode("push2", "StackPushV2", {"enter2_stack2", "enter2_c"}, g);

  item.keep_ops = {"stack1", "stack2"};

  LoopOptimizer loop_optimizer;
  EnableOnlyStackPushRemoval(&loop_optimizer);
  GraphDef optimized;
  TF_EXPECT_OK(loop_optimizer.Optimize(nullptr, item, &optimized));
  VerifyGraphsEqual(item.graph, optimized, __FUNCTION__);
}
697 
TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
  // Pushes whose value can never be observed (no pop, or a pop without
  // consumers) are expected to be rewritten into Identity nodes that keep a
  // control dependency on the stack; everything else must stay unchanged.
  GrapplerItem item;
  GraphDef& graph = item.graph;
  AddSimpleNode("c", "Const", {}, &graph);
  // Push without Pop.
  AddSimpleNode("stack1", "StackV2", {}, &graph);
  AddSimpleNode("push1", "StackPushV2", {"stack1", "c"}, &graph);
  // Push without Pop behind Enter.
  AddSimpleNode("stack2", "StackV2", {}, &graph);
  AddEnterNode("enter_c", "frame_name", false, 1, {"c"}, &graph);
  AddEnterNode("enter_stack2", "frame_name", false, 1, {"stack2"}, &graph);
  AddSimpleNode("push2", "StackPushV2", {"enter_stack2", "enter_c"}, &graph);
  // Pop without consumer.
  AddSimpleNode("stack3", "StackV2", {}, &graph);
  AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph);
  AddSimpleNode("pop3", "StackPopV2", {"stack3"}, &graph);
  // Push for a Pop without consumer that is fetched should not be removed.
  AddSimpleNode("stack4", "StackV2", {}, &graph);
  AddSimpleNode("push4", "StackPushV2", {"stack4", "c"}, &graph);
  AddSimpleNode("pop4", "StackPopV2", {"stack4"}, &graph);

  item.fetch.push_back("pop4");

  LoopOptimizer optimizer;
  EnableOnlyStackPushRemoval(&optimizer);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Fatal: the loop below indexes item.graph with output's indices, so any
  // size mismatch would read past the end of the input graph.
  ASSERT_EQ(output.node_size(), 13);
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "push1") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "c");
      EXPECT_EQ(node.input(1), "^stack1");
    } else if (node.name() == "push2") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "enter_c");
      EXPECT_EQ(node.input(1), "^enter_stack2");
    } else if (node.name() == "push3") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "c");
      EXPECT_EQ(node.input(1), "^stack3");
    } else {
      // All other nodes, including push4 (its pop is fetched), are untouched.
      const NodeDef& orig_node = item.graph.node(i);
      EXPECT_EQ(node.ShortDebugString(), orig_node.ShortDebugString());
    }
  }
}
751 
TEST_F(LoopOptimizerTest,RemoveDeadBranchesConstantCondition)752 TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
753   Scope scope = Scope::NewRootScope();
754   Output v_in = ops::Const<float>(scope.WithOpName("v_in"), {123.0}, {});
755 
756   Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({}));
757   ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
758   Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
759   Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true);
760 
761   Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
762   ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
763   Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
764   Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
765 
766   Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
767   ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
768   Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
769   Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
770 
771   Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
772   ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
773   Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
774   Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
775 
776   ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
777   ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
778   ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
779   ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
780   ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
781 
782   ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
783   Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
784   Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
785   ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
786 
787   ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
788   Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
789   Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
790   ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
791 
792   GrapplerItem item;
793   item.fetch.push_back("m8");
794   item.fetch.push_back("id4");
795 
796   TF_CHECK_OK(scope.ToGraphDef(&item.graph));
797 
798   LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
799   GraphDef output;
800   Status status = optimizer.Optimize(nullptr, item, &output);
801   TF_CHECK_OK(status);
802 
803   for (const NodeDef& node : output.node()) {
804     // These nodes should have been pruned
805     EXPECT_NE(node.name(), "Square1");
806     EXPECT_NE(node.name(), "Sqrt2");
807     EXPECT_NE(node.name(), "m5");
808 
809     if (node.name() == "m1") {
810       // sqrt1 is dead
811       EXPECT_EQ(node.op(), "Identity");
812       ASSERT_EQ(node.input_size(), 1);
813       EXPECT_EQ(node.input(0), "square1");
814     } else if (node.name() == "m2") {
815       // both inputs are alive
816       EXPECT_EQ(node.op(), "Merge");
817       ASSERT_EQ(node.input_size(), 2);
818       EXPECT_EQ(node.input(0), "v_in");
819       EXPECT_EQ(node.input(1), "square1");
820     } else if (node.name() == "m3") {
821       // sqrt1 is dead
822       EXPECT_EQ(node.op(), "Identity");
823       ASSERT_EQ(node.input_size(), 1);
824       EXPECT_EQ(node.input(0), "v_in");
825     } else if (node.name() == "m4") {
826       // both inputs are alive
827       EXPECT_EQ(node.op(), "Merge");
828       ASSERT_EQ(node.input_size(), 2);
829       EXPECT_EQ(node.input(0), "square1");
830       EXPECT_EQ(node.input(1), "sqrt2");
831     } else if (node.name() == "m8") {
832       // The node is to be preserved because of a fetch
833       EXPECT_EQ(node.op(), "Merge");
834       ASSERT_EQ(node.input_size(), 2);
835       EXPECT_EQ(node.input(0), "id1");
836       EXPECT_EQ(node.input(1), "id2");
837     } else if (node.name() == "m9") {
838       // The node is to be preserved because of a fetch
839       EXPECT_EQ(node.op(), "Merge");
840       ASSERT_EQ(2, node.input_size());
841       EXPECT_EQ(node.input(0), "id3");
842       EXPECT_EQ(node.input(1), "id4");
843     } else if (node.name() == "switch1") {
844       // The node can be replaced by Identity with control_dependency
845       EXPECT_EQ(node.op(), "Identity");
846       ASSERT_EQ(node.input_size(), 2);
847       EXPECT_EQ(node.input(0), "v_in");
848       EXPECT_EQ(node.input(1), "^ctrl1");
849     } else if (node.name() == "switch2") {
850       // The node can be replaced by Identity with control_dependency
851       EXPECT_EQ(node.op(), "Identity");
852       ASSERT_EQ(node.input_size(), 2);
853       EXPECT_EQ(node.input(0), "v_in");
854       EXPECT_EQ(node.input(1), "^ctrl2");
855     } else if (node.name() == "switch3") {
856       // The node can be replaced by Identity with control_dependency
857       EXPECT_EQ(node.op(), "Identity");
858       ASSERT_EQ(node.input_size(), 2);
859       EXPECT_EQ(node.input(0), "v_in");
860       EXPECT_EQ(node.input(1), "^ctrl3");
861     } else if (node.name() == "switch4") {
862       // The node can be replaced by Identity with control_dependency
863       EXPECT_EQ(node.op(), "Identity");
864       ASSERT_EQ(node.input_size(), 2);
865       EXPECT_EQ(node.input(0), "v_in");
866       EXPECT_EQ(node.input(1), "^ctrl4");
867     } else if (node.name() == "switch5") {
868       // The node should remain unchanged
869       EXPECT_EQ(node.op(), "Switch");
870       ASSERT_EQ(node.input_size(), 2);
871       EXPECT_EQ(node.input(0), "v_in");
872       EXPECT_EQ(node.input(1), "ctrl1");
873     } else if (node.name() == "switch6") {
874       // The node should remain unchanged
875       EXPECT_EQ(node.op(), "Switch");
876       ASSERT_EQ(node.input_size(), 2);
877       EXPECT_EQ(node.input(0), "v_in");
878       EXPECT_EQ(node.input(1), "ctrl1");
879     }
880   }
881 
882   auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
883   ASSERT_EQ(tensors_expected.size(), 2);
884 
885   auto tensors = EvaluateNodes(output, {"m8", "m9"});
886   ASSERT_EQ(tensors.size(), 2);
887 
888   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
889   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
890 }
891 
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition2) {
  // switch1's predicate is the constant `true`: its false output (consumed
  // only by square1) is dead, while add1/add2/sub1 on the true output live.
  Scope scope = Scope::NewRootScope();
  Output v_in = ops::Const<float>(scope.WithOpName("v_in"), {123.0}, {});

  Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), true, TensorShape({}));
  ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);

  Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);

  Output add1 =
      ops::Add(scope.WithOpName("add1"), s1.output_true, s1.output_true);

  Output const2 = ops::Const<float>(scope.WithOpName("const2"), {20.0}, {});
  Output add2 = ops::Add(scope.WithOpName("add2"), s1.output_true, const2);

  Output sub1 = ops::Sub(scope.WithOpName("sub1"), add1, add2);

  ops::Merge m1(scope.WithOpName("m1"), {square1, sub1});
  Output add3 = ops::Add(scope.WithOpName("add3"), m1.output, const2);

  GrapplerItem item;
  item.fetch.push_back("add3");

  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_CHECK_OK(status);

  for (const NodeDef& node : output.node()) {
    // square1 only reads the dead false output of switch1 and must have been
    // pruned. Node names are case-sensitive, so compare against the exact
    // lowercase name used when building the graph.
    EXPECT_NE(node.name(), "square1");

    if (node.name() == "m1") {
      // square1 is dead
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub1");
    } else if (node.name() == "switch1") {
      // The switch collapses to Identity with a control dependency on ctrl1.
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "v_in");
      EXPECT_EQ(node.input(1), "^ctrl1");
    }
  }
}
939 
TEST_F(LoopOptimizerTest, RemoveDeadBranchesFullyRemoveDeadBranches) {
  // Graph excerpt containing an AssertGuard conditional whose Merge is a
  // control input of the fetched ReadVariableOp_3. Dead-branch removal must
  // not delete that Merge.
  const string graph_pbtxt = R"EOF(
node {
  name: "episodicreplaybuffer_add_readvariableop_resource"
  op: "_Arg"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_RESOURCE
    }
  }
  attr {
    key: "index"
    value {
      i: 0
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/and_1/x"
  op: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        bool_val: true
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/begin_episode"
  op: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        bool_val: false
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Switch"
  op: "Switch"
  input: "EpisodicReplayBuffer/add/and_1/x"
  input: "EpisodicReplayBuffer/add/and_1/x"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/NoOp"
  op: "NoOp"
  input: "^EpisodicReplayBuffer/add/and_1/x"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch"
  op: "Switch"
  input: "EpisodicReplayBuffer/add/and_1/x"
  input: "EpisodicReplayBuffer/add/and_1/x"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@EpisodicReplayBuffer/add/assert_equal/All"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_1"
  op: "Switch"
  input: "EpisodicReplayBuffer/add/begin_episode"
  input: "EpisodicReplayBuffer/add/and_1/x"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@EpisodicReplayBuffer/add/begin_episode"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_2"
  op: "Switch"
  input: "EpisodicReplayBuffer/add/begin_episode"
  input: "EpisodicReplayBuffer/add/and_1/x"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@EpisodicReplayBuffer/add/end_episode"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
  op: "Identity"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Switch"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
  op: "Const"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/NoOp"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        tensor_content: "\001"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
  op: "Assert"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_1"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert/Switch_2"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      list {
        type: DT_BOOL
        type: DT_BOOL
      }
    }
  }
  attr {
    key: "summarize"
    value {
      i: 3
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
  op: "Identity"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/switch_f"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
  op: "Merge"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
  input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/FloorMod/y"
  op: "Const"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT64
        tensor_shape {
        }
        int64_val: 5000
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/ReadVariableOp"
  op: "ReadVariableOp"
  input: "episodicreplaybuffer_add_readvariableop_resource"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/Less/y"
  op: "Const"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT64
        tensor_shape {
        }
        int64_val: 0
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/Less"
  op: "Less"
  input: "EpisodicReplayBuffer/add/ReadVariableOp"
  input: "EpisodicReplayBuffer/add/Less/y"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_INT64
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/or"
  op: "LogicalOr"
  input: "EpisodicReplayBuffer/add/begin_episode"
  input: "EpisodicReplayBuffer/add/Less"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
}
node {
  name: "EpisodicReplayBuffer/add/get_episode_id/pred_id"
  op: "Identity"
  input: "EpisodicReplayBuffer/add/or"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/get_episode_id/Switch"
  op: "Switch"
  input: "EpisodicReplayBuffer/add/or"
  input: "EpisodicReplayBuffer/add/or"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/AssignVariableOp/Switch"
  op: "Switch"
  input: "episodicreplaybuffer_add_readvariableop_resource"
  input: "EpisodicReplayBuffer/add/get_episode_id/pred_id"
  input: "^EpisodicReplayBuffer/add/ReadVariableOp"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_RESOURCE
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@EpisodicReplayBuffer/add/ReadVariableOp/resource"
      }
    }
  }
}
node {
  name: "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/ReadVariableOp_3"
  op: "ReadVariableOp"
  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_INT64
    }
  }
}
library {
}
versions {
  producer: 27
}
  )EOF";

  GrapplerItem item;
  CHECK(protobuf::TextFormat::ParseFromString(graph_pbtxt, &item.graph));
  item.fetch = {
      "EpisodicReplayBuffer/add/get_episode_id/critical_section_execute/"
      "ReadVariableOp_3"};

  LoopOptimizer loop_optimizer(RewriterConfig::AGGRESSIVE, nullptr);
  GraphDef optimized;
  TF_CHECK_OK(loop_optimizer.Optimize(nullptr, item, &optimized));

  const string merge_name =
      "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Merge";
  bool merge_survived = false;
  for (int i = 0; i < optimized.node_size() && !merge_survived; ++i) {
    merge_survived = optimized.node(i).name() == merge_name;
  }

  EXPECT_TRUE(merge_survived)
      << "Merge node was deleted, but it shouldn't have been.";
}
1350 
TEST_F(LoopOptimizerTest, RemoveDeadBranchesZeroIterWhile) {
  // A while loop that starts at 20, adds 1 per iteration and loops while the
  // value is < 10. The condition is false on entry, so the loop body
  // (everything fed by while/Switch:1) can never execute.
  const string gdef_ascii = R"EOF(
node {
  name: "Const"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 20
      }
    }
  }
}
node {
  name: "while/Enter"
  op: "Enter"
  input: "Const"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "frame_name"
    value {
      s: "while/while/"
    }
  }
  attr {
    key: "is_constant"
    value {
      b: false
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 1
    }
  }
}
node {
  name: "while/Merge"
  op: "Merge"
  input: "while/Enter"
  input: "while/NextIteration"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Less/y"
  op: "Const"
  input: "^while/Merge"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 10
      }
    }
  }
}
node {
  name: "while/Less"
  op: "Less"
  input: "while/Merge"
  input: "while/Less/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/LoopCond"
  op: "LoopCond"
  input: "while/Less"
}
node {
  name: "while/Switch"
  op: "Switch"
  input: "while/Merge"
  input: "while/LoopCond"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@while/Merge"
      }
    }
  }
}
node {
  name: "while/Identity"
  op: "Identity"
  input: "while/Switch:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/add/y"
  op: "Const"
  input: "^while/Identity"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 1
      }
    }
  }
}
node {
  name: "while/add"
  op: "Add"
  input: "while/Identity"
  input: "while/add/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/NextIteration"
  op: "NextIteration"
  input: "while/add"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "while/Exit"
  op: "Exit"
  input: "while/Switch"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
versions {
  producer: 21
}
  )EOF";

  GrapplerItem item;
  CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph));
  item.fetch = {"while/Exit"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_CHECK_OK(status);
  auto tensors_got = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors_got.size(), 1);
  test::ExpectTensorEqual<int32>(tensors_got[0], tensors_expected[0]);

  // All nodes connected to the Switch's true (loop-body) output must be
  // pruned. Use real expectations so a regression fails the test instead of
  // only emitting a log line.
  for (const NodeDef& node : output.node()) {
    EXPECT_NE(node.name(), "while/Identity");
    EXPECT_NE(node.name(), "while/add/y");
    EXPECT_NE(node.name(), "while/add");
    EXPECT_NE(node.name(), "while/NextIteration");
  }
  // The 12 input nodes minus the 4 pruned loop-body nodes.
  EXPECT_EQ(output.node_size(), 8);
}
1582 
TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantFeed) {
  // A tf.cond whose predicate Const_1 is a Const node but is also fed. The
  // constant's stored value therefore does not decide the branch, and the
  // conditional is expected to be left alone.
  const string graph_pbtxt = R"EOF(
node {
  name: "Const"
  op: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
          dim {
            size: 1
          }
        }
        string_val: "I\'m a value!"
      }
    }
  }
}
node {
  name: "cond/Switch_1"
  op: "Switch"
  input: "Const"
  input: "Const_1"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@Const"
      }
    }
  }
}
node {
  name: "Const_1"
  op: "Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_BOOL
        tensor_shape {
        }
        bool_val: true
      }
    }
  }
}
node {
  name: "cond/Switch"
  op: "Switch"
  input: "Const_1"
  input: "Const_1"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/switch_t"
  op: "Identity"
  input: "cond/Switch:1"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Const"
  op: "Const"
  input: "^cond/switch_t"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "dtype"
    value {
      type: DT_STRING
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_STRING
        tensor_shape {
          dim {
            size: 1
          }
        }
        string_val: ""
      }
    }
  }
}
node {
  name: "cond/Merge"
  op: "Merge"
  input: "cond/Switch_1"
  input: "cond/Const"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "N"
    value {
      i: 2
    }
  }
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Merge"
  device: "/job:localhost/replica:0/task:0/device:CPU:0"
  attr {
    key: "T"
    value {
      type: DT_STRING
    }
  }
}
library {
}
versions {
  producer: 27
}
  )EOF";

  GrapplerItem item;
  CHECK(protobuf::TextFormat::ParseFromString(graph_pbtxt, &item.graph));
  item.fetch = {"Identity"};

  // Override the constant predicate with a fed scalar `false`.
  Tensor predicate_feed(DT_BOOL, {});
  predicate_feed.scalar<bool>()() = false;
  item.feed.emplace_back("Const_1", predicate_feed);

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  LoopOptimizer loop_optimizer(RewriterConfig::AGGRESSIVE, nullptr);
  GraphDef optimized;
  TF_CHECK_OK(loop_optimizer.Optimize(nullptr, item, &optimized));
  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<tstring>(actual[0], expected[0]);

  EXPECT_EQ(optimized.node_size(), 8);

  // No rewrite because branch has a constant feed node: cond/Merge must still
  // be a two-input Merge.
  const NodeDef* merge = nullptr;
  for (int i = 0; i < optimized.node_size() && merge == nullptr; ++i) {
    if (optimized.node(i).name() == "cond/Merge") {
      merge = &optimized.node(i);
    }
  }
  ASSERT_NE(merge, nullptr);
  EXPECT_EQ(merge->op(), "Merge");
  ASSERT_EQ(merge->input_size(), 2);
  EXPECT_EQ(merge->input(0), "cond/Switch_1");
  EXPECT_EQ(merge->input(1), "cond/Const");
}
1775 
1776 }  // namespace grappler
1777 }  // namespace tensorflow
1778