xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 
18 #include <complex>
19 
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/cc/ops/array_ops.h"
24 #include "tensorflow/cc/ops/math_ops.h"
25 #include "tensorflow/cc/ops/nn_ops.h"
26 #include "tensorflow/cc/ops/resource_variable_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
32 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
33 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/test.h"
37 
38 namespace tensorflow {
39 namespace grappler {
40 
41 namespace {
42 
// Name prefixes that the individual ArithmeticOptimizer stages attach to the
// nodes they create; the helpers below use them to compute the expected names
// of optimized nodes.
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_AddV2_";

constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
57 
58 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)59 string HoistMulName(const string& name) {
60   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
61 }
62 
63 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)64 string HoistDivName(const string& name) {
65   return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
66 }
67 
68 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)69 string HoistAddName(const string& name) {
70   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
71 }
72 
73 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)74 string AggregationConstName(const string& name) {
75   return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
76 }
77 
78 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)79 string AggregationMulName(const string& name) {
80   return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
81 }
82 
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)83 void VerifyGraphsMatch(const GraphDef& original_graph,
84                        const GraphDef& optimized_graph, int line) {
85   EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
86   for (int i = 0; i < original_graph.node_size(); ++i) {
87     const NodeDef& original = original_graph.node(i);
88     const NodeDef& optimized = optimized_graph.node(i);
89     EXPECT_EQ(original.name(), optimized.name()) << line;
90     EXPECT_EQ(original.op(), optimized.op()) << line;
91     EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
92     for (int j = 0; j < original.input_size(); ++j) {
93       EXPECT_EQ(original.input(j), optimized.input(j)) << line;
94     }
95   }
96 }
97 }  // namespace
98 
// The optimizer must leave a graph with nothing to optimize untouched.
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  ArithmeticOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
111 
// Multiplying by a constant tensor of ones that only broadcasts (all
// non-broadcast dims are 1) should be rewritten into a Tile node, and the
// rewritten graph must produce the same values.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
  // Graph from b/176172427
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The Mul must be gone, replaced by exactly one Tile.
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(t, nullptr);
  ASSERT_NE(c, nullptr);
  EXPECT_EQ(t->op(), "Tile");
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), "input");
  EXPECT_EQ(t->input(1), c->name());
  EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type());

  // Numerical equivalence of the rewritten graph.
  auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
156 
// The Tile rewrite must preserve control dependencies that were attached to
// the constant of ones.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
                           1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // The replacement Const must carry the control edge from "input".
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(c, nullptr);
  ASSERT_EQ(c->input_size(), 1);
  EXPECT_TRUE(IsControlInput(c->input(0)));
  EXPECT_EQ(c->input(0), "^input");
}
190 
// No rewrite should happen when the multiplication does not broadcast
// (operands already have identical shapes).
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Optimizer must be a no-op for this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
219 
// No rewrite should happen when the broadcast operand is not a constant.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 1, 1}));
  Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 2, 1}));
  Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch,
                                {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Optimizer must be a no-op for this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(item.graph, item.fetch,
                              {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
251 
// No rewrite should happen when the broadcast constant is not all ones
// (multiplying by 2 actually scales the values, so a Tile would be wrong).
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 2.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Optimizer must be a no-op for this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
280 
TEST_F(ArithmeticOptimizerTest,ReduceUpsamplingDims)281 TEST_F(ArithmeticOptimizerTest, ReduceUpsamplingDims) {
282   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
283   Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
284                                   ops::Placeholder::Shape({1, 22, 48, 64}));
285   Output reshape_a = ops::Reshape(
286       s.WithOpName("reshape_a"), input,
287       ops::Const(s.WithOpName("shape_a"), {1, 22, 1, 48, 1, 64}, {6}));
288   Output tile =
289       ops::Tile(s.WithOpName("tile"), reshape_a,
290                 ops::Const(s.WithOpName("multiples"), {1, 1, 2, 1, 2, 1}, {6}));
291   Output reshape_b =
292       ops::Reshape(s.WithOpName("reshape_b"), tile,
293                    ops::Const(s.WithOpName("shape_b"), {1, 44, 96, 64}));
294   Output output = ops::Identity(s.WithOpName("output"), reshape_b);
295 
296   GrapplerItem item;
297   item.fetch = {"output"};
298   TF_CHECK_OK(s.ToGraphDef(&item.graph));
299   auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 22, 48, 64}));
300   auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
301   ASSERT_EQ(expected.size(), 1);
302 
303   GraphDef g;
304   ArithmeticOptimizer optimizer;
305   EnableOnlyReduceUpsamplingDims(&optimizer);
306   OptimizeTwiceAndPrune(&optimizer, &item, &g);
307   EXPECT_EQ(g.node_size(), 8);
308 
309   ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
310   ASSERT_EQ(CountOpNodes(g, "Reshape"), 2);
311   ASSERT_EQ(CountOpNodes(g, "Const"), 3);
312 
313   NodeMap node_map(&g);
314   const string p = "ArithmeticOptimizer/ReduceUpsamplingDims";
315   const NodeDef* ra =
316       node_map.GetNode(absl::StrCat(p, "_", "Reshape_reshape_b"));
317   const NodeDef* rb = node_map.GetNode("reshape_b");
318   const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_reshape_b"));
319   ASSERT_NE(ra, nullptr);
320   ASSERT_NE(rb, nullptr);
321   ASSERT_NE(t, nullptr);
322 
323   ASSERT_EQ(rb->input_size(), 2);
324   EXPECT_EQ(rb->input(0), t->name());
325   ASSERT_EQ(t->input_size(), 2);
326   EXPECT_EQ(t->input(0), ra->name());
327   ASSERT_EQ(ra->input_size(), 2);
328   EXPECT_EQ(ra->input(0), "input");
329 
330   {
331     auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
332     ASSERT_EQ(result.size(), 1);
333     test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
334   }
335 
336   // Check to make sure the first reshape is removed
337   EnableOnlyRemoveRedundantReshape(&optimizer);
338   OptimizeTwiceAndPrune(&optimizer, &item, &g);
339   EXPECT_EQ(g.node_size(), 6);
340 
341   {
342     auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
343     ASSERT_EQ(result.size(), 1);
344     test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
345   }
346 }
347 
// Mul(x, x) and MulNoNan(x, x) should both be rewritten to Square(x), with
// any control dependencies carried over to the Square node.
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  // Square replacing Mul keeps the control edge "^d".
  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  EXPECT_EQ(square_node->input(1), "^d");

  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
391 
TEST_F(ArithmeticOptimizerTest,ReplacePackWithTileReshape)392 TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshape) {
393   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
394   Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
395                               ops::Placeholder::Shape({3, 5, 7, 11}));
396   // Stack creates Pack nodes
397   Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
398   Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
399   Output o = ops::Identity(s.WithOpName("output"), c);
400 
401   GrapplerItem item;
402   item.fetch = {"output"};
403   TF_CHECK_OK(s.ToGraphDef(&item.graph));
404   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
405   auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
406   ASSERT_EQ(expected.size(), 1);
407 
408   GraphDef g;
409   ArithmeticOptimizer optimizer;
410   EnableOnlyReplacePackWithTileReshape(&optimizer);
411   OptimizeAndPrune(&optimizer, &item, &g);
412 
413   EXPECT_EQ(g.node_size(), 6);
414   EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
415   EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
416   EXPECT_EQ(CountOpNodes(g, "Const"), 2);
417   EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);
418 
419   NodeMap node_map(&g);
420   const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
421   const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
422   const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
423   const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
424   const NodeDef* r_node = node_map.GetNode(absl::StrCat(p, "_", "Reshape_c"));
425   const NodeDef* a_node = node_map.GetNode("a");
426   ASSERT_NE(t_node, nullptr);
427   ASSERT_NE(c_node, nullptr);
428   ASSERT_NE(s_node, nullptr);
429   ASSERT_NE(r_node, nullptr);
430   ASSERT_NE(a_node, nullptr);
431 
432   EXPECT_EQ(c_node->op(), "Const");
433   EXPECT_EQ(s_node->op(), "Const");
434 
435   // Check Reshape properties
436   ASSERT_EQ(r_node->input_size(), 2);
437   EXPECT_EQ(r_node->op(), "Reshape");
438   EXPECT_EQ(r_node->input(0), t_node->name());
439   EXPECT_EQ(r_node->input(1), s_node->name());
440 
441   // Check Tile properties
442   ASSERT_EQ(t_node->input_size(), 2);
443   EXPECT_EQ(t_node->op(), "Tile");
444   EXPECT_EQ(t_node->input(0), a_node->name());
445   EXPECT_EQ(t_node->input(1), c_node->name());
446   EXPECT_EQ(t_node->attr().at("T").type(), DT_FLOAT);
447   EXPECT_EQ(t_node->attr().at("Tmultiples").type(),
448             c_node->attr().at("dtype").type());
449 
450   auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
451   ASSERT_EQ(result.size(), 1);
452   test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
453 }
454 
TEST_F(ArithmeticOptimizerTest,ReplacePackWithTileReshapeControlDeps)455 TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeControlDeps) {
456   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
457   Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
458                               ops::Placeholder::Shape({3, 5, 7, 11}));
459 
460   Output x = ops::Identity(s.WithOpName("x"), a);
461   Output y = ops::Identity(s.WithOpName("y"), a);
462 
463   Output b = ops::Stack(s.WithOpName("b").WithControlDependencies(x), {a, a},
464                         ops::Stack::Axis(3));
465   Output c = ops::Stack(s.WithOpName("c").WithControlDependencies(y), {b, b},
466                         ops::Stack::Axis(2));
467   Output o = ops::Identity(s.WithOpName("output"), c);
468 
469   GrapplerItem item;
470   item.fetch = {"output"};
471   item.keep_ops = {"x", "y"};
472   TF_CHECK_OK(s.ToGraphDef(&item.graph));
473   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
474   auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
475   ASSERT_EQ(expected.size(), 1);
476 
477   GraphDef g;
478   ArithmeticOptimizer optimizer;
479   EnableOnlyReplacePackWithTileReshape(&optimizer);
480   OptimizeAndPrune(&optimizer, &item, &g);
481 
482   EXPECT_EQ(g.node_size(), 8);
483   EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
484   EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
485   EXPECT_EQ(CountOpNodes(g, "Const"), 2);
486   EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);
487   EXPECT_EQ(CountOpNodes(g, "Identity"), 3);
488 
489   NodeMap node_map(&g);
490   const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
491   const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
492   const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
493   const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
494   const NodeDef* a_node = node_map.GetNode("a");
495   ASSERT_NE(t_node, nullptr);
496   ASSERT_NE(c_node, nullptr);
497   ASSERT_NE(s_node, nullptr);
498   ASSERT_NE(a_node, nullptr);
499 
500   ASSERT_EQ(t_node->input_size(), 4);
501   EXPECT_EQ(t_node->op(), "Tile");
502   EXPECT_EQ(t_node->input(0), a_node->name());
503   EXPECT_EQ(t_node->input(1), c_node->name());
504   EXPECT_EQ(t_node->input(2), "^y");
505   EXPECT_EQ(t_node->input(3), "^x");
506 
507   ASSERT_EQ(c_node->input_size(), 1);
508   EXPECT_EQ(c_node->input(0), "^a");
509 
510   ASSERT_EQ(s_node->input_size(), 1);
511   ASSERT_EQ(s_node->input(0), "^a");
512 
513   auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
514   ASSERT_EQ(result.size(), 1);
515   test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
516 }
517 
// After the Pack->Tile/Reshape rewrite, a pre-existing downstream Reshape
// becomes redundant and should be removed by RemoveRedundantReshape.
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileRemoveReshape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output r =
      ops::Reshape(s.WithOpName("r"), c, ops::Const(s, {3, 10, 14, 11}, {4}));
  Output o = ops::Identity(s.WithOpName("output"), r);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 3);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 2);

  // The second pass collapses the back-to-back reshapes.
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
560 
// No rewrite should happen when the Pack axis is the last (out-of-range for
// this optimization) axis.
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeOutOfRange) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(4));
  Output o = ops::Identity(s.WithOpName("output"), b);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // Optimizer must be a no-op for this graph.
  VerifyGraphsMatch(item.graph, g, __LINE__);
}
580 
// Adjacent pairs of involutions (Neg(Neg(x)), Reciprocal(Reciprocal(x)))
// should cancel, leaving the constant wired directly into the Identity.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other.
  ASSERT_EQ(output.node_size(), 2);
  EXPECT_EQ(output.node(1).name(), "id");
  ASSERT_EQ(output.node(1).input_size(), 1);
  EXPECT_EQ(output.node(1).input(0), "c");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
612 
// Involutions separated by a chain of value-preserving ops (Identity,
// Squeeze) should still cancel across that chain.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(output.node_size(), 3);

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      found++;
    } else if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "squeeze");
      found++;
    }
  }
  EXPECT_EQ(found, 2);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
658 
// Involution removal must not fire across control dependencies: the second
// Reciprocal only has a control edge to the chain, so nothing may be removed.
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
691 
// Add(x, x) should be rewritten by SimplifyAggregation into Mul(2, x), with
// the new Const(2) anchored to x via a control edge.
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  // The new Const holds float 2.0 (little-endian bytes "\0\0\0@") and is
  // anchored on x so it is placed/executed with it.
  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 2);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
737 
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // x + x carrying a control dependency on y; the rewrite to 2*x must keep
  // the control edge alive on the replacement Mul node.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(scope.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(scope.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add =
      ops::Add(scope.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  const std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 6);

  const string const_name = AggregationConstName("add");
  const string mul_name = AggregationMulName("add");

  const NodeDef* const_node = node_map.GetNode(const_name);
  ASSERT_NE(const_node, nullptr);
  ASSERT_EQ(const_node->input_size(), 1);
  EXPECT_EQ(const_node->input(0), "^x");
  // Scalar 2.0f serialized as raw little-endian IEEE-754 bytes.
  EXPECT_EQ(const_node->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  const NodeDef* mul_node = node_map.GetNode(mul_name);
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 3);
  EXPECT_EQ(mul_node->input(0), const_name);
  EXPECT_EQ(mul_node->input(1), "x");
  // The control dependency on y survives the rewrite.
  EXPECT_EQ(mul_node->input(2), "^y");

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), mul_name);

  auto tensors = EvaluateNodes(optimized_graph, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
785 
TEST_F(ArithmeticOptimizerTest,TrivialSumsRepeatedAdd)786 TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
787   // Test case from b/69059093.
788   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
789   Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
790   Output add = ops::Add(s.WithOpName("Add"), p, p);
791   Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
792   Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
793   Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
794   Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
795   Output id = ops::Identity(s.WithOpName("id"), add6);
796 
797   GrapplerItem item;
798   item.fetch = {"id"};
799   TF_CHECK_OK(s.ToGraphDef(&item.graph));
800 
801   const std::vector<string> devices{
802       "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
803       "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
804   };
805   for (int i = 0; i < item.graph.node_size(); ++i) {
806     item.graph.mutable_node(i)->set_device(devices[i]);
807   }
808 
809   ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
810   DisableAddToAddNCombining(&optimizer);
811 
812   GraphDef output;
813   DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);
814 
815   // We expect the following rewrite(s) to occur:
816   //
817   // Mul(p,
818   //     Add_6(Add_4(Const(2), Const(2)),
819   //           Add_5(Const(2), Const(2)))
820   NodeMap node_map(&output);
821 
822   EXPECT_EQ(output.node_size(), 8);
823 
824   const NodeDef* id_node = node_map.GetNode("id");
825   ASSERT_NE(id_node, nullptr);
826   ASSERT_EQ(id_node->input_size(), 1);
827   EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));
828 
829   const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
830   ASSERT_NE(mul_node, nullptr);
831   ASSERT_EQ(mul_node->input_size(), 2);
832   EXPECT_EQ(mul_node->input(0), "Placeholder");
833   EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
834 
835   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
836   ASSERT_NE(add_6_node, nullptr);
837   ASSERT_EQ(add_6_node->input_size(), 2);
838   EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
839   EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));
840 
841   const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
842   ASSERT_NE(add_4_node, nullptr);
843   EXPECT_EQ(add_4_node->op(), "Add");
844   ASSERT_EQ(2, add_4_node->input_size());
845   EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
846   EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));
847 
848   const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
849   ASSERT_NE(add_5_node, nullptr);
850   EXPECT_EQ(add_5_node->op(), "Add");
851   ASSERT_EQ(add_5_node->input_size(), 2);
852   EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
853   EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));
854 
855   const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
856   ASSERT_NE(add_const_node, nullptr);
857   EXPECT_EQ(add_const_node->op(), "Const");
858   ASSERT_EQ(add_const_node->input_size(), 1);
859   EXPECT_EQ(add_const_node->input(0), "^Placeholder");
860 
861   const NodeDef* add_1_const_node =
862       node_map.GetNode(AggregationConstName("Add_1"));
863   ASSERT_NE(add_1_const_node, nullptr);
864   EXPECT_EQ(add_1_const_node->op(), "Const");
865   ASSERT_EQ(add_1_const_node->input_size(), 1);
866   EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
867 }
868 
TEST_F(ArithmeticOptimizerTest,HoistFactorMul)869 TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
870   for (bool matching_shapes : {true, false}) {
871     for (bool use_addn : {true, false}) {
872       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
873       Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
874       Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
875       Output y2 = matching_shapes
876                       ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
877                       : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
878       Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
879       Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
880       Output id =
881           use_addn ? ops::Identity(s.WithOpName("id"),
882                                    ops::AddN(s.WithOpName("add"), {mul1, mul2}))
883                    : ops::Identity(s.WithOpName("id"),
884                                    ops::Add(s.WithOpName("add"), mul1, mul2));
885 
886       GrapplerItem item;
887       item.fetch = {"id"};
888       TF_CHECK_OK(s.ToGraphDef(&item.graph));
889       auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
890       ASSERT_EQ(tensors_expected.size(), 1);
891       ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
892       EnableOnlyHoistCommonFactor(&optimizer);
893 
894       GraphDef output;
895       OptimizeTwice(&optimizer, &item, &output);
896 
897       // We expect the following rewrite(s) to occur:
898       //
899       //        Add                 Mul
900       //      /    \               /   \
901       //    Mul    Mul       ->   x    Add
902       //    / \    / \                 / \
903       //   x  y1  y2  x              y1   y2
904       //
905       // If "root" op is AddN and shapes does not match, this rewrite is not
906       // possible and graph should stay intact.
907       NodeMap node_map(&output);
908 
909       if (use_addn && !matching_shapes) {
910         VerifyGraphsMatch(item.graph, output, __LINE__);
911       } else {
912         EXPECT_EQ(output.node_size(), 9);
913 
914         const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
915         ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
916         ASSERT_EQ(new_add_node->input_size(), 2);
917         EXPECT_EQ(new_add_node->input(0), "y1");
918         EXPECT_EQ(new_add_node->input(1), "y2");
919 
920         const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
921         ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
922         ASSERT_EQ(new_mul_node->input_size(), 2);
923         EXPECT_EQ(new_mul_node->input(0), "x");
924         EXPECT_EQ(new_mul_node->input(1), new_add_node->name());
925 
926         const NodeDef* id_node = node_map.GetNode("id");
927         ASSERT_NE(id_node, nullptr) << "Id node not found";
928         EXPECT_EQ(id_node->name(), "id");
929         ASSERT_EQ(id_node->input_size(), 1);
930         EXPECT_EQ(id_node->input(0), HoistMulName("add"));
931       }
932       auto tensors = EvaluateNodes(output, item.fetch);
933       ASSERT_EQ(tensors.size(), 1);
934       test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
935     }
936   }
937 }
938 
TEST_F(ArithmeticOptimizerTest,HoistFactorDiv)939 TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
940   for (bool matching_shapes : {true, false}) {
941     for (bool use_addn : {true, false}) {
942       for (bool use_ints : {true, false}) {
943         tensorflow::Scope s = tensorflow::Scope::NewRootScope();
944         Output x = use_ints
945                        ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
946                        : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
947         Output y1 = use_ints
948                         ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
949                         : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
950         Output y2;
951         if (matching_shapes) {
952           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
953                         : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
954         } else {
955           y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
956                         : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
957         }
958         Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
959         Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
960         Output id =
961             use_addn
962                 ? ops::Identity(s.WithOpName("id"),
963                                 ops::AddN(s.WithOpName("add"), {div1, div2}))
964                 : ops::Identity(s.WithOpName("id"),
965                                 ops::Add(s.WithOpName("add"), div1, div2));
966 
967         GrapplerItem item;
968         item.fetch = {"id"};
969         TF_CHECK_OK(s.ToGraphDef(&item.graph));
970 
971         auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
972         ASSERT_EQ(tensors_expected.size(), 1);
973 
974         ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
975         EnableOnlyHoistCommonFactor(&optimizer);
976 
977         GraphDef output;
978         OptimizeTwice(&optimizer, &item, &output);
979 
980         // We expect the following rewrite(s) to occur:
981         //
982         //        Add                 Div
983         //      /    \               /   \
984         //    Div    Div       ->  Add    x
985         //    / \    / \           / \
986         //   y1  x  y2  x         y1  y2
987         //
988         // If "root" op is AddN and shapes does not match, this rewrite is not
989         // possible and graph should stay intact.
990         NodeMap node_map(&output);
991 
992         if ((use_addn && !matching_shapes) || use_ints) {
993           VerifyGraphsMatch(item.graph, output, __LINE__);
994         } else {
995           EXPECT_EQ(output.node_size(), 9);
996 
997           const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
998           ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
999           ASSERT_EQ(new_add_node->input_size(), 2);
1000           EXPECT_EQ(new_add_node->input(0), "y1");
1001           EXPECT_EQ(new_add_node->input(1), "y2");
1002 
1003           const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
1004           ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
1005           ASSERT_EQ(new_div_node->input_size(), 2);
1006           EXPECT_EQ(new_div_node->input(0), new_add_node->name());
1007           EXPECT_EQ(new_div_node->input(1), "x");
1008 
1009           const NodeDef* id_node = node_map.GetNode("id");
1010           ASSERT_TRUE(id_node != nullptr) << "Id node not found";
1011           EXPECT_EQ("id", id_node->name());
1012           ASSERT_EQ(id_node->input_size(), 1);
1013           EXPECT_EQ(id_node->input(0), HoistDivName("add"));
1014         }
1015         auto tensors = EvaluateNodes(output, item.fetch);
1016         ASSERT_EQ(tensors.size(), 1);
1017         if (use_ints) {
1018           test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
1019         } else {
1020           test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
1021         }
1022       }
1023     }
1024   }
1025 }
1026 
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  // Transpose(Conj(z), perm) should be fused into ConjugateTranspose(z, perm).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp = ops::Transpose(scope.WithOpName("trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "trans");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1064 
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  // ConjugateTranspose(Conj(z), perm): the two conjugations cancel, leaving a
  // plain Transpose.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conj = ops::Conj(scope.WithOpName("conj"), z);
  Output transp =
      ops::ConjugateTranspose(scope.WithOpName("conjugate_trans"), conj, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conjugate_trans");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "Transpose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1104 
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  // Conj(Transpose(z, perm)) should be fused into ConjugateTranspose(z, perm).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output trans = ops::Transpose(scope.WithOpName("trans"), z, perm);
  Output conj = ops::Conj(scope.WithOpName("conj"), trans);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);
  NodeMap node_map(&optimized_graph);

  EXPECT_EQ(optimized_graph.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conj");

  const NodeDef* fused_node = node_map.GetNode(fused_name);
  ASSERT_NE(fused_node, nullptr);
  EXPECT_EQ(fused_node->op(), "ConjugateTranspose");
  ASSERT_EQ(fused_node->input_size(), 2);
  EXPECT_EQ(fused_node->input(0), "z");
  EXPECT_EQ(fused_node->input(1), "perm");

  auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<complex64>(tensors[0], tensors_expected[0]);
}
1142 
TEST_F(ArithmeticOptimizerTest,FoldTransposeIntoMatMul)1143 TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
1144   for (const string matmul_type :
1145        {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
1146     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1147 
1148     Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
1149     Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
1150     Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
1151     Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
1152     Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
1153 
1154     Output matmul;
1155     auto matmul_op = s.WithOpName("matmul");
1156     if (matmul_type == "MatMul") {
1157       matmul = ops::MatMul(matmul_op, trans_a, trans_b);
1158     } else if (matmul_type == "SparseMatMul") {
1159       matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
1160     } else if (matmul_type == "BatchMatMul") {
1161       matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
1162     } else if (matmul_type == "BatchMatMulV2") {
1163       matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
1164     }
1165 
1166     auto identity = ops::Identity(s.WithOpName("identity"), matmul);
1167 
1168     GrapplerItem item;
1169     item.fetch = {"identity"};
1170     TF_CHECK_OK(s.ToGraphDef(&item.graph));
1171 
1172     auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
1173     ASSERT_EQ(tensors_expected.size(), 1);
1174 
1175     ArithmeticOptimizer optimizer;
1176     EnableOnlyFoldTransposeIntoMatMul(&optimizer);
1177     GraphDef output;
1178     OptimizeTwice(&optimizer, &item, &output);
1179     NodeMap node_map(&output);
1180 
1181     EXPECT_EQ(output.node_size(), 8);
1182 
1183     const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
1184     const string optimized_name = absl::StrCat(p, "_", "matmul");
1185 
1186     const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
1187     ASSERT_NE(matmul_fused_node, nullptr);
1188     ASSERT_EQ(matmul_fused_node->input_size(), 2);
1189     EXPECT_EQ(matmul_fused_node->input(0), "a");
1190     EXPECT_EQ(matmul_fused_node->input(1), "b");
1191 
1192     if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
1193       EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
1194       EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
1195     } else {
1196       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
1197       EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
1198     }
1199 
1200     const NodeDef* identity_node = node_map.GetNode("identity");
1201     ASSERT_NE(identity_node, nullptr);
1202     ASSERT_EQ(identity_node->input_size(), 1);
1203     EXPECT_EQ(identity_node->input(0), optimized_name);
1204 
1205     auto tensors = EvaluateNodes(output, item.fetch);
1206     ASSERT_EQ(tensors.size(), 1);
1207     test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
1208   }
1209 }
1210 
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  // BatchMatMul(ConjugateTranspose(a), ConjugateTranspose(b)) should fold the
  // transposes away by setting the adjoint attributes on the matmul.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output re_a =
      ops::Const(scope.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a = ops::Const(scope.WithOpName("im_a"),
                           {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(scope.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b = ops::Const(scope.WithOpName("im_b"),
                           {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(scope.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(scope.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(scope.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(scope.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(scope.WithOpName("matmul"), trans_a, trans_b);
  Output identity = ops::Identity(scope.WithOpName("identity"), matmul);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef optimized_graph;
  OptimizeTwice(&optimizer, &item, &optimized_graph);

  NodeMap node_map(&optimized_graph);
  EXPECT_EQ(optimized_graph.node_size(), 12);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldTransposeIntoMatMul", "_", "matmul");

  const NodeDef* fused_matmul = node_map.GetNode(fused_name);
  ASSERT_NE(fused_matmul, nullptr);
  ASSERT_EQ(fused_matmul->input_size(), 2);
  EXPECT_EQ(fused_matmul->input(0), "a");
  EXPECT_EQ(fused_matmul->input(1), "b");
  EXPECT_TRUE(fused_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(fused_matmul->attr().at("adj_y").b());

  auto tensors = EvaluateNodes(optimized_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
}
1259 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(s, inputs);
    // The target shape of the reshape is the concatenation of `batch_size` and
    // [3,28,28], i.e. the input's own (symbolic) shape.
    Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                   ops::Const(s, {1}, {1}));
    Output target_shape = ops::Concat(
        s.WithOpName("target_shape"),
        {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
    // Select the shape-changing op under test, then build the Identity once:
    // this replaces the original if/else whose branches each declared a dead
    // branch-local `outputs` and duplicated the Identity construction.
    Output reshaped = is_broadcastto
                          ? Output(ops::BroadcastTo(s, inputs, target_shape))
                          : Output(ops::Reshape(s, inputs, target_shape));
    Output outputs = ops::Identity(s.WithOpName("outputs"), reshaped);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // Both the Reshape and the BroadcastTo flavors are shape-preserving here
    // and should be removed.
    EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(output, "BroadcastTo"), 0);
    auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
1301 
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(scope, inputs);
  // Reshape to [batch_size, 3, height, width] — exactly the symbolic shape of
  // the input, sliced out of its runtime Shape.
  Output batch_size =
      ops::Slice(scope, inputs_shape, ops::Const(scope, {0}, {1}),
                 ops::Const(scope, {1}, {1}));
  Output height = ops::Slice(scope, inputs_shape, ops::Const(scope, {2}, {1}),
                             ops::Const(scope, {1}, {1}));
  Output width = ops::Slice(scope, inputs_shape, ops::Const(scope, {3}, {1}),
                            ops::Const(scope, {1}, {1}));
  Output target_shape =
      ops::Concat(scope.WithOpName("target_shape"),
                  {batch_size, ops::Const(scope, {3}, {1}), height, width},
                  ops::Const(scope, {0}, {}));
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  // Aggressive mode assumes the feed shape is valid, which lets the optimizer
  // recognize the symbolic reshape as an identity.
  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 0);
  auto tensors = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1343 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  // In the default (non-aggressive) mode the actual feed may have a shape
  // different from the placeholder's declared one, so the reshape is kept.
  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 1);

  auto tensors = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1374 
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(scope, inputs, target_shape);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  // In aggressive mode the optimizer trusts the placeholder's declared shape,
  // so the shape-preserving reshape is elided.
  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 0);
  auto tensors = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1403 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not an identity: e.g. an
  // input of shape [16,3,28,28] would legally reshape to [8,6,28,28], so the
  // optimizer must keep the Reshape node.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                   ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape =
      ops::Reshape(scope, inputs, ops::Const(scope, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized_graph);

  EXPECT_EQ(CountOpNodes(optimized_graph, "Reshape"), 1);
  auto tensors = EvaluateNodes(optimized_graph, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1431 
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  // With more than one unknown (-1) target dimension the optimizer cannot
  // prove the reshape is an identity, so the Reshape node must be preserved.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshape =
      ops::Reshape(scope, inputs, ops::Const(scope, {-1, -1}, {2}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &optimized);

  EXPECT_EQ(CountOpNodes(optimized, "Reshape"), 1);
}
1451 
TEST_F(ArithmeticOptimizerTest,RemoveRedundantReshapeCombineReshapes)1452 TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
1453   for (bool include_unary_chain : {false, true}) {
1454     // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The
1455     // two reshapes should be combined.
1456     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
1457     Output nchw_vect_c =
1458         ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_FLOAT,
1459                          ops::Placeholder::Shape({8, 3, 28, 28, 4}));
1460     Output transpose =
1461         ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
1462                        ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
1463     Output nhwc = ops::Reshape(
1464         s.WithOpName("nhwc"), transpose,
1465         ops::Const(
1466             s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
1467             {8, 28, 28, 12}, {4}));
1468     Output flatten = ops::Reshape(
1469         s.WithOpName("flatten"),
1470         (include_unary_chain ? ops::Cos(s.WithOpName("Cos"), nhwc) : nhwc),
1471         ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
1472     Output output0 = ops::Identity(s.WithOpName("output0"), flatten);
1473     Output output1 = ops::Identity(s.WithOpName("output1"), flatten);
1474 
1475     GraphDef graph;
1476     TF_CHECK_OK(s.ToGraphDef(&graph));
1477     auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28, 4}));
1478     auto eval =
1479         EvaluateNodes(graph, {"output0", "nhwc"}, {{"nchw_vect_c", x_t}});
1480 
1481     ASSERT_EQ(eval.size(), 2);
1482     auto expected_output_t = eval[0];
1483     auto nhwc_t = eval[1];
1484 
1485     {
1486       GrapplerItem item;
1487       item.graph = graph;
1488       item.fetch = {"output0", "output1"};
1489       item.feed = {{"nchw_vect_c", x_t}};
1490 
1491       GraphDef output;
1492       ArithmeticOptimizer optimizer;
1493       EnableOnlyRemoveRedundantReshape(&optimizer);
1494       OptimizeTwiceAndPrune(&optimizer, &item, &output);
1495 
1496       EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
1497       auto tensors = EvaluateNodes(output, item.fetch, item.feed);
1498       ASSERT_EQ(tensors.size(), 2);
1499       test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
1500       test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
1501     }
1502 
1503     // Test when the first reshape node output is the feed tensor.
1504     // (Expected no reshape removal to happen.)
1505     {
1506       GrapplerItem item;
1507       item.graph = graph;
1508       item.fetch = {"output0", "output1"};
1509       item.feed = {{"nhwc", nhwc_t}};
1510 
1511       GraphDef output;
1512       ArithmeticOptimizer optimizer;
1513       EnableOnlyRemoveRedundantReshape(&optimizer);
1514       OptimizeTwiceAndPrune(&optimizer, &item, &output);
1515 
1516       EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
1517       auto tensors = EvaluateNodes(output, item.fetch, item.feed);
1518       ASSERT_EQ(tensors.size(), 2);
1519       test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
1520       test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
1521     }
1522 
1523     // Test when the first reshape node output is consumed by multiple nodes
1524     // (Expected no reshape removal to happen.)
1525     {
1526       Output output2 = ops::Identity(s.WithOpName("output2"), nhwc);
1527       GraphDef graph;
1528       TF_CHECK_OK(s.ToGraphDef(&graph));
1529       GrapplerItem item;
1530       item.graph = graph;
1531       item.fetch = {"output0", "output1", "output2"};
1532       item.feed = {{"nchw_vect_c", x_t}};
1533 
1534       GraphDef output;
1535       ArithmeticOptimizer optimizer;
1536       EnableOnlyRemoveRedundantReshape(&optimizer);
1537       OptimizeTwiceAndPrune(&optimizer, &item, &output);
1538 
1539       EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
1540       auto tensors = EvaluateNodes(output, item.fetch, item.feed);
1541       ASSERT_EQ(tensors.size(), 3);
1542       test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
1543       test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
1544       test::ExpectTensorEqual<float>(tensors[2], nhwc_t);
1545     }
1546   }
1547 }
1548 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  // Cast(uint8->float) followed by Transpose should be reordered so the
  // Transpose runs on the narrower uint8 data.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 = ops::Placeholder(scope, DT_UINT8,
                                       ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(scope, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(scope, nhwc_fp32, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose should remain, now operating on uint8.
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    }
  }
  ASSERT_NE(transpose_node, nullptr);

  // Every Cast must now consume the Transpose output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): item.graph is evaluated here on purpose — presumably
  // OptimizeAndPrune swaps the optimized graph into the item; confirm in
  // arithmetic_optimizer_test_utils.h.
  auto actual =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<float>(actual[0], expected[0]);
}
1593 
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  // Cast(uint8->float) -> SpaceToDepth should be reordered so SpaceToDepth
  // operates on the narrower uint8 data.
  Output net = ops::Placeholder(scope, DT_UINT8,
                                ops::Placeholder::Shape({8, 28, 28, 3}));
  net = ops::Cast(scope, net, DT_FLOAT);
  net = ops::SpaceToDepth(scope, net, 2);
  net = ops::Identity(scope.WithOpName("outputs"), net);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth should remain, now operating on uint8.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      s2d_node = &node;
    }
  }
  ASSERT_NE(s2d_node, nullptr);

  // Every Cast must now consume the SpaceToDepth output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
    }
  }

  auto actual =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<float>(actual[0], expected[0]);
}
1640 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  // Transpose followed by Cast(float->uint8) should be reordered so the
  // Transpose runs on the narrower uint8 data.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 = ops::Placeholder(scope, DT_FLOAT,
                                      ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(scope, nhwc_fp32, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(scope, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // A constant input makes the float->uint8 cast deterministic.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Cast should remain, and it must consume the placeholder.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  ASSERT_NE(cast_node, nullptr);

  // The Transpose now operates on uint8 and consumes the Cast output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
    }
  }

  auto actual =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<uint8>(actual[0], expected[0]);
}
1688 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  // A Cast followed by a Reverse+Transpose chain should be pushed past the
  // whole chain: Placeholder -> Reverse -> Transpose -> Cast.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 = ops::Placeholder(scope, DT_UINT8,
                                       ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(scope, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(scope, nhwc_fp32, ops::Const(scope, {0}, {1}));
  Output nchw_fp32_reversed = ops::Transpose(
      scope, nhwc_fp32_reversed, ops::Const(scope, {0, 3, 1, 2}, {4}));

  Output outputs =
      ops::Identity(scope.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Reverse and Transpose should both now run on uint8, with the Cast last.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  // Verify the rewired chain: Placeholder -> Reverse -> Transpose -> Cast.
  ASSERT_EQ(reverse_node->input_size(), 2);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  auto actual =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<float>(actual[0], expected[0]);
}
1745 
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  // CheckNumerics after a Cast must not trigger the reorder rewrite; the
  // optimizer should leave the graph untouched.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 = ops::Placeholder(scope, DT_UINT8,
                                       ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(scope, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 = ops::CheckNumerics(scope, nhwc_fp32, "foo");
  Output outputs = ops::Identity(scope.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1762 
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  // Cast narrows float->uint8 before the Transpose, so the data is already
  // small; reordering would not help and must not happen.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 = ops::Placeholder(scope, DT_FLOAT,
                                      ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_uint8 = ops::Cast(scope, nhwc_fp32, DT_UINT8);
  Output nchw_uint8 =
      ops::Transpose(scope, nhwc_uint8, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1780 
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  // The Transpose already runs on the narrow uint8 data before the widening
  // Cast; reordering would not help and must not happen.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 = ops::Placeholder(scope, DT_UINT8,
                                       ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_uint8 =
      ops::Transpose(scope, nhwc_uint8, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output nchw_fp32 = ops::Cast(scope, nchw_uint8, DT_FLOAT);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
  CompareGraphs(item.graph, output);
}
1798 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  // perm1 followed by perm2 composes to the identity permutation, and perm3
  // is itself the identity; all three Transpose nodes should be removed.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm3 = ops::Const(scope.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output transpose1 =
      ops::Transpose(scope.WithOpName("transpose1"), inputs, perm1);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), transpose1, perm2);
  Output transpose3 =
      ops::Transpose(scope.WithOpName("transpose3"), inputs, perm3);
  Output id1 = ops::Identity(scope.WithOpName("id1"), transpose2);
  Output id2 = ops::Identity(scope.WithOpName("id2"), transpose3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Only the fetches and the input subgraph should survive.
  std::set<string> remaining_nodes;
  for (const NodeDef& node : output.node()) {
    remaining_nodes.insert(node.name());
  }
  EXPECT_EQ(remaining_nodes,
            std::set<string>({"id1", "id2", "inputs_shape", "inputs"}));
}
1831 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  // A ConjugateTranspose with an identity permutation should be rewritten to
  // a plain Conj on the complex input.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output re =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), re, im);
  Output perm = ops::Const(scope.WithOpName("perm"), {0, 1}, {2});
  Output transpose =
      ops::ConjugateTranspose(scope.WithOpName("trans"), z, perm);
  Output id = ops::Identity(scope.WithOpName("id"), transpose);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  // The stage renames the rewritten node with its optimizer prefix.
  const string prefix = "ArithmeticOptimizer/RemoveIdentityTranspose";
  const string optimized_name = absl::StrCat(prefix, "_", "trans");

  const NodeDef* conj = node_map.GetNode(optimized_name);
  ASSERT_NE(conj, nullptr);
  EXPECT_EQ(conj->op(), "Conj");
  ASSERT_EQ(conj->input_size(), 1);
  EXPECT_EQ(conj->input(0), "z");
}
1862 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  // A cancelling Transpose pair on one branch of a Split should be removed,
  // rewiring Concat directly to the Split outputs. Several ops below rely on
  // Scope auto-naming ("Split", "Const", ...), so construction order matters.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // NOTE(review): inputs_shape is not consumed by any op below — it appears
  // to be vestigial; the placeholder carries its own static shape.
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  OutputList split = ops::Split(scope, ops::Const(scope, 1), inputs, 3).output;
  Output perm1 = ops::Const(scope, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  Output branch1 =
      ops::Transpose(scope, ops::Transpose(scope, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat =
      ops::Concat(scope, {branch0, branch1, branch2}, ops::Const(scope, 1));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", input_tensor}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Concat should now read the Split outputs directly.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Concat") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }

  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1905 
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  // When a cancelling Transpose pair is removed, a control dependency on the
  // pair must be transferred to the pair's input (the placeholder).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output transpose1 = ops::Transpose(scope, inputs, ops::Const(scope, {1, 0}));
  Output transpose2 =
      ops::Transpose(scope, transpose1, ops::Const(scope, {1, 0}));
  Output outputs = ops::Identity(
      scope.WithOpName("outputs").WithControlDependencies(transpose2),
      ops::Const(scope.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", input_tensor}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The control edge now points at the placeholder instead of the transposes.
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1940 
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  // Applying perm {1,2,3,0} twice is not the identity permutation, so both
  // Transpose nodes must be kept.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output transpose1 =
      ops::Transpose(scope.WithOpName("transpose1"), inputs, perm);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), transpose1, perm);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), transpose2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All six nodes (shape, input, perm, two transposes, identity) remain.
  EXPECT_EQ(output.node_size(), 6);
}
1964 
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  // A cancelling Transpose pair separated by an Identity node should still be
  // removed (aggressive mode), leaving the Identity chain wired to the input.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output transpose1 =
      ops::Transpose(scope.WithOpName("transpose1").WithControlDependencies(perm2),
                     inputs, perm1);
  Output identity = ops::Identity(scope.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(scope.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Both transposes are gone; "id" reads "inputs" and "id1" reads "id".
  std::set<string> remaining_nodes;
  for (const NodeDef& node : output.node()) {
    remaining_nodes.insert(node.name());
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "inputs");
    }
    if (node.name() == "id1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id");
    }
  }
  EXPECT_EQ(remaining_nodes,
            std::set<string>({"id", "id1", "inputs_shape", "inputs"}));
}
2004 
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  // A scalar multiply in front of Transpose+Conv2D should be folded into the
  // convolution weights, regardless of the Mul's operand order.
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
    Output scaled_inputs = ops::Multiply(scope.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(scope.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights,
                    {1, 1, 1, 1}, "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    NodeMap node_map(&output);
    // `conv` is now a folded convolution with scaled weights.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);

    // The weight input is a Mul that applies the scale to the weights.
    const NodeDef* folded_conv_weights = node_map.GetNode(folded_conv->input(1));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // The data input is a transpose of the unscaled `inputs`.
    const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
2054 
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  // Feeding the Transpose output makes it a preserved node; the Mul must not
  // be folded across it into the convolution weights.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(scope.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

  // Zero-filled feed tensor for the preserved Transpose output.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0,
         inputs_nchw_tensor.tensor_data().size());

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  // The Transpose still reads the scaled inputs: no folding happened.
  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
2096 
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  // A scalar multiply feeding a Conv3D should be folded into the convolution
  // weights, leaving the conv to read the unscaled inputs directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  // Use EXPECT_EQ (not CHECK_EQ): CHECK_EQ aborts the whole test binary on
  // failure instead of reporting a test failure, unlike every sibling test.
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
2131 
TEST_F(ArithmeticOptimizerTest,OptimizeCastMulTransposeConv)2132 TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
2133   // This unit test exercises two optimizations, folding mul into conv, and
2134   // reordering cast and transpose.
2135   //
2136   //   Conv2D(Transpose(Mul(Cast(I), S)), W)
2137   //     =>
2138   //   Conv2D(Transpose(Cast(I)), W*S)
2139   //     =>
2140   //   Conv2D(Cast(Transpose(I)), W*S)
2141   tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");
2142 
2143   Output inputs =
2144       ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
2145   Output cast = ops::Cast(s, inputs, DT_FLOAT);
2146   Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
2147   Output transpose =
2148       ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
2149   Output weights = ops::Const(s.WithOpName("weights"),
2150                               Input::Initializer(127.0f, {5, 5, 3, 16}));
2151   Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
2152                             ops::Conv2D::DataFormat("NCHW"));
2153   Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
2154 
2155   GrapplerItem item;
2156   item.fetch = {"outputs"};
2157   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2158 
2159   GraphDef output;
2160   ArithmeticOptimizer optimizer;  // all optimization stages are on
2161   OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
2162   NodeMap node_map(&output);
2163 
2164   // Expected names for reordered cast and transpose.
2165   const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
2166   const string optimized_cast_name = absl::StrCat(p, "float_Cast");
2167   const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");
2168 
2169   // Expected names for folded multiply and conv.
2170   const string optimized_weights =
2171       "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";
2172 
2173   const NodeDef* inputs_node = node_map.GetNode("Placeholder");
2174   const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
2175   const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
2176 
2177   const NodeDef* weights_node = node_map.GetNode(optimized_weights);
2178   const NodeDef* conv_node = node_map.GetNode("Conv2D");
2179 
2180   ASSERT_NE(inputs_node, nullptr);
2181   ASSERT_NE(transpose_node, nullptr);
2182   ASSERT_NE(cast_node, nullptr);
2183   ASSERT_NE(weights_node, nullptr);
2184   ASSERT_NE(conv_node, nullptr);
2185 
2186   EXPECT_EQ(output.node_size(), 7);
2187   ASSERT_EQ(transpose_node->input_size(), 2);
2188   EXPECT_EQ(transpose_node->input(0), inputs_node->name());
2189   ASSERT_EQ(cast_node->input_size(), 1);
2190   EXPECT_EQ(cast_node->input(0), transpose_node->name());
2191   ASSERT_EQ(conv_node->input_size(), 2);
2192   EXPECT_EQ(conv_node->input(0), cast_node->name());
2193   EXPECT_EQ(conv_node->input(1), weights_node->name());
2194 }
2195 
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // Folding mul into conv must be applied independently to each of the two
  // convolutions in the graph.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  for (int i = 0; i < 2; ++i) {
    Output inputs = ops::Placeholder(scope, DT_FLOAT,
                                     ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(scope, inputs, ops::Const(scope, 1.0f / 255.0f));
    Output weights = ops::Const(scope.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(scope, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(scope.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  GraphDef output;
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);

  NodeMap node_map(&output);

  using absl::StrCat;
  const string prefix = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string scaled_weights_name = StrCat(prefix, "scaled_Conv2D_weights");
  const string scaled_weights_name_1 =
      StrCat(prefix, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(scaled_weights_name);
  const NodeDef* weights_node_1 = node_map.GetNode(scaled_weights_name_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  // Each convolution now reads its own rescaled weights constant.
  ASSERT_EQ(conv_node->input_size(), 2);
  ASSERT_EQ(conv_node_1->input_size(), 2);
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
2245 
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  // Two chained Bitcasts with distinct source/destination types collapse
  // into a single Bitcast.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // A single Bitcast remains, reading straight from `inputs`.
  EXPECT_EQ(output.node_size(), 3);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2279 
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  // When the combined Bitcast would map a type onto itself, both casts are
  // removed entirely.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope, bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // No Bitcasts remain; `inputs` feeds `outputs` directly.
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2313 
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  // A Cast whose destination type equals its source type is a no-op and is
  // eliminated.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output cast = ops::Cast(scope, inputs, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_t}};
  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Cast is gone; `inputs` is wired straight into `outputs`.
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<int8>(actual[0], expected[0]);
}
2346 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteAddOpsOfIdenticalShape)2347 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
2348   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2349   tensorflow::Scope sx = s.NewSubScope("x");
2350   tensorflow::Scope sy = s.NewSubScope("y");
2351 
2352   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
2353   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
2354   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
2355   auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
2356   auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);
2357 
2358   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2359 
2360   GrapplerItem item;
2361   item.fetch = {"outputs"};
2362   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2363 
2364   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2365   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2366   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2367   std::vector<std::pair<string, Tensor>> feed = {
2368       {"a", a_t}, {"b", b_t}, {"c", c_t}};
2369   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2370   ASSERT_EQ(tensors_expected.size(), 1);
2371 
2372   GraphDef output;
2373   ArithmeticOptimizer optimizer;
2374   EnableOnlyAddToAddNCombining(&optimizer);
2375 
2376   OptimizeAndPrune(&optimizer, &item, &output);
2377 
2378   // We expect the following rewrite(s) to occur:
2379   //
2380   //     +
2381   //    / \
2382   //   a   +         -->    AddN(a, b, c)
2383   //      / \
2384   //     b   c
2385   EXPECT_EQ(output.node_size(), 5);
2386 
2387   NodeMap node_map(&output);
2388 
2389   // check add tree was replaced with AddN
2390   const NodeDef* collapsed_add =
2391       node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2392   ASSERT_NE(collapsed_add, nullptr);
2393 
2394   EXPECT_EQ(collapsed_add->op(), "AddN");
2395   ASSERT_EQ(collapsed_add->input_size(), 3);
2396   EXPECT_EQ(collapsed_add->input(0), "a");
2397   EXPECT_EQ(collapsed_add->input(1), "b");
2398   EXPECT_EQ(collapsed_add->input(2), "c");
2399 
2400   // check output was re-wired to new node
2401   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2402   ASSERT_NE(updated_outputs, nullptr);
2403   ASSERT_EQ(updated_outputs->input_size(), 1);
2404   EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());
2405 
2406   auto tensors = EvaluateNodes(output, item.fetch, feed);
2407   ASSERT_EQ(tensors.size(), 1);
2408   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2409 }
2410 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMultiplePasses)2411 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
2412   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2413 
2414   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
2415   auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
2416   auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
2417   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2418   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2419 
2420   auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
2421   auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
2422   auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
2423   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
2424   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
2425 
2426   auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
2427   auto outputs = ops::Identity(s.WithOpName("outputs"), mul);
2428 
2429   GrapplerItem item;
2430   item.fetch = {"outputs"};
2431   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2432 
2433   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2434   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2435   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2436   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2437   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2438   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2439   std::vector<std::pair<string, Tensor>> feed = {
2440       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
2441   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2442   ASSERT_EQ(tensors_expected.size(), 1);
2443 
2444   GraphDef output;
2445   ArithmeticOptimizer optimizer;
2446   EnableOnlyAddToAddNCombining(&optimizer);
2447 
2448   OptimizeAndPrune(&optimizer, &item, &output);
2449 
2450   // We expect the following rewrite(s) to occur:
2451   //
2452   //         *
2453   //      /     \
2454   //     +       +                        *
2455   //    / \     / \                    /     \
2456   //   +   c   x   + -->    AddN(a, b, c)  AddN(x, y, z))
2457   //  / \         / \
2458   // a   b       y   z
2459   EXPECT_EQ(output.node_size(), 10);
2460 
2461   NodeMap node_map(&output);
2462 
2463   // check left Add subtree replaced with AddN
2464   const NodeDef* collapsed_left =
2465       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2466   ASSERT_NE(collapsed_left, nullptr);
2467 
2468   EXPECT_EQ(collapsed_left->op(), "AddN");
2469   ASSERT_EQ(collapsed_left->input_size(), 3);
2470   EXPECT_EQ(collapsed_left->input(0), "a");
2471   EXPECT_EQ(collapsed_left->input(1), "b");
2472   EXPECT_EQ(collapsed_left->input(2), "c");
2473 
2474   // check right Add subtree replaced with AddN
2475   const NodeDef* collapsed_right =
2476       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
2477   ASSERT_NE(collapsed_right, nullptr);
2478 
2479   EXPECT_EQ(collapsed_right->op(), "AddN");
2480   ASSERT_EQ(collapsed_right->input_size(), 3);
2481   EXPECT_EQ(collapsed_right->input(0), "x");
2482   EXPECT_EQ(collapsed_right->input(1), "y");
2483   EXPECT_EQ(collapsed_right->input(2), "z");
2484 
2485   // check that Mul inputs re-wired to new Nodes
2486   const NodeDef* updated_mul = node_map.GetNode("Mul");
2487   ASSERT_NE(updated_mul, nullptr);
2488 
2489   EXPECT_EQ(updated_mul->op(), "Mul");
2490   ASSERT_EQ(updated_mul->input_size(), 2);
2491   EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
2492   EXPECT_EQ(updated_mul->input(1), collapsed_right->name());
2493 
2494   auto tensors = EvaluateNodes(output, item.fetch, feed);
2495   ASSERT_EQ(tensors.size(), 1);
2496   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2497 }
2498 
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(scope.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(scope.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(scope.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(scope.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_val = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_val}, {"b", b_val}, {"c", c_val}};
  auto expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite:
  //
  //     +
  //    / \
  //   +   +     -->    AddN(a, b, b, c)
  //  / \ / \                   ^
  // a   b   c                  b added twice!
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // The Add tree collapses into one AddN; the shared input `b` must appear
  // twice among its inputs.
  const NodeDef* addn =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(addn, nullptr);

  EXPECT_EQ(addn->op(), "AddN");
  ASSERT_EQ(addn->input_size(), 4);
  EXPECT_EQ(addn->input(0), "a");
  EXPECT_EQ(addn->input(1), "b");
  EXPECT_EQ(addn->input(2), "b");
  EXPECT_EQ(addn->input(3), "c");

  auto actual = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
2555 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteAddOpsOfSymbolicallyEqualShape)2556 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
2557   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2558 
2559   // unknown input shape propagated symbolically through the graph
2560   auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
2561 
2562   // [a, b, c] have symbolically equal shapes
2563   auto a = ops::Sqrt(s.WithOpName("a"), input);
2564   auto b = ops::Square(s.WithOpName("b"), input);
2565   auto c = ops::Round(s.WithOpName("c"), input);
2566 
2567   // [add_ab, add_abc] shape must be inferred from inputs
2568   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2569   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2570 
2571   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2572 
2573   GrapplerItem item;
2574   item.fetch = {"outputs"};
2575   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2576 
2577   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
2578   std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
2579   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2580   ASSERT_EQ(tensors_expected.size(), 1);
2581 
2582   GraphDef output;
2583   ArithmeticOptimizer optimizer;
2584   EnableOnlyAddToAddNCombining(&optimizer);
2585 
2586   OptimizeAndPrune(&optimizer, &item, &output);
2587 
2588   // We expect the following rewrite(s) to occur:
2589   //
2590   //     +
2591   //    / \
2592   //   +   c      -->    AddN(a, b, c)
2593   //  / \
2594   // a   b
2595   EXPECT_EQ(output.node_size(), 6);
2596 
2597   NodeMap node_map(&output);
2598 
2599   // check add tree was replaced with AddN
2600   const NodeDef* collapsed_add =
2601       node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
2602   ASSERT_NE(collapsed_add, nullptr);
2603   EXPECT_EQ(collapsed_add->op(), "AddN");
2604   ASSERT_EQ(collapsed_add->input_size(), 3);
2605   EXPECT_EQ(collapsed_add->input(0), "a");
2606   EXPECT_EQ(collapsed_add->input(1), "b");
2607   EXPECT_EQ(collapsed_add->input(2), "c");
2608 
2609   // check output was re-wired to new node
2610   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2611   ASSERT_NE(updated_outputs, nullptr);
2612   ASSERT_EQ(updated_outputs->input_size(), 1);
2613   EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());
2614 
2615   auto tensors = EvaluateNodes(output, item.fetch, feed);
2616   ASSERT_EQ(tensors.size(), 1);
2617   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2618 }
2619 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMinimizeBCast)2620 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
2621   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2622 
2623   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
2624   auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
2625   auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
2626   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2627   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2628 
2629   auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
2630   auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
2631   auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
2632   auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
2633   auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
2634 
2635   auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
2636   auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
2637 
2638   GrapplerItem item;
2639   item.fetch = {"outputs"};
2640   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2641 
2642   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2643   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2644   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2645   auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
2646   auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
2647   auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
2648   std::vector<std::pair<string, Tensor>> feed = {
2649       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
2650   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2651   ASSERT_EQ(tensors_expected.size(), 1);
2652 
2653   GraphDef output;
2654   ArithmeticOptimizer optimizer;
2655   EnableOnlyAddToAddNCombining(&optimizer);
2656 
2657   OptimizeAndPrune(&optimizer, &item, &output);
2658 
2659   // We expect the following rewrite(s) to occur:
2660   //  1) [a, x], [b, y], [c, z] - aggregate same shapes first
2661   //  2) Build an aggregation tree minimizing cost of broadcast
2662   //
2663   //         +                              +
2664   //      /     \                       /       \
2665   //     +       +                     +       AddN(c, z)
2666   //    / \     / \                 /     \
2667   //   +   c   x   + -->    AddN(a, x)  AddN(b, y)
2668   //  / \         / \
2669   // a   b       y   z
2670   EXPECT_EQ(output.node_size(), 12);
2671   NodeMap node_map(&output);
2672 
2673   // expected names of outer and inner nodes
2674   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
2675   string outer_0_add_name =
2676       "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
2677   string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
2678   string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
2679   string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";
2680 
2681   // Add [a, x] first
2682   const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
2683   ASSERT_NE(add_ax_node, nullptr);
2684   EXPECT_EQ(add_ax_node->op(), "AddN");
2685   ASSERT_EQ(add_ax_node->input_size(), 2);
2686   EXPECT_EQ(add_ax_node->input(0), "a");
2687   EXPECT_EQ(add_ax_node->input(1), "x");
2688 
2689   // Then add [b, y]
2690   const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
2691   ASSERT_NE(add_by_node, nullptr);
2692   EXPECT_EQ(add_by_node->op(), "AddN");
2693   ASSERT_EQ(2, add_by_node->input_size());
2694   EXPECT_EQ(add_by_node->input(0), "b");
2695   EXPECT_EQ(add_by_node->input(1), "y");
2696 
2697   // Then add [c, z]
2698   const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
2699   ASSERT_NE(add_cz_node, nullptr);
2700   EXPECT_EQ(add_cz_node->op(), "AddN");
2701   ASSERT_EQ(add_cz_node->input_size(), 2);
2702   EXPECT_EQ(add_cz_node->input(0), "c");
2703   EXPECT_EQ(add_cz_node->input(1), "z");
2704 
2705   // Then add results together starting from smaller shapes [a, x] + [b, y]
2706   const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
2707   ASSERT_NE(outer_0_node, nullptr);
2708   EXPECT_EQ(outer_0_node->op(), "AddV2");
2709   ASSERT_EQ(outer_0_node->input_size(), 2);
2710   EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
2711   EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);
2712 
2713   // And finally top level Add node
2714   const NodeDef* outer_node = node_map.GetNode(outer_add_name);
2715   ASSERT_NE(outer_node, nullptr);
2716   EXPECT_EQ(outer_node->op(), "AddV2");
2717   ASSERT_EQ(outer_node->input_size(), 2);
2718   EXPECT_EQ(outer_node->input(0), outer_0_add_name);
2719   EXPECT_EQ(outer_node->input(1), inner_2_add_name);
2720 
2721   // And outputs reading new top level Add node
2722   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2723   ASSERT_NE(updated_outputs, nullptr);
2724   ASSERT_EQ(updated_outputs->input_size(), 1);
2725   EXPECT_EQ(updated_outputs->input(0), outer_add_name);
2726 
2727   auto tensors = EvaluateNodes(output, item.fetch, feed);
2728   ASSERT_EQ(tensors.size(), 1);
2729   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
2730 }
2731 
TEST_F(ArithmeticOptimizerTest,AddOpsRewriteMinimizeBCastWithSymbolicShapes)2732 TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
2733   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2734 
2735   // We have a small input with one unknown dimension
2736   auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);
2737 
2738   // And second input which is larger, but has the same unknown dimension
2739   // device spec prevents this node from rewriting
2740   auto d = "/device:CPU:0";
2741   auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
2742   auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);
2743 
2744   // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
2745   auto a = ops::Sqrt(s.WithOpName("a"), small);
2746   auto b = ops::Square(s.WithOpName("b"), large);
2747   auto c = ops::Round(s.WithOpName("c"), small);
2748 
2749   // [add_ab, add_abc] shape must be inferred from inputs
2750   auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
2751   auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
2752 
2753   auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
2754 
2755   GrapplerItem item;
2756   item.fetch = {"outputs"};
2757   TF_CHECK_OK(s.ToGraphDef(&item.graph));
2758 
2759   auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
2760   auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
2761   std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
2762   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
2763   ASSERT_EQ(tensors_expected.size(), 1);
2764 
2765   GraphDef output;
2766   ArithmeticOptimizer optimizer;
2767   EnableOnlyAddToAddNCombining(&optimizer);
2768   OptimizeAndPrune(&optimizer, &item, &output);
2769 
2770   // We expect the following rewrite(s) to occur: it's much cheaper to add small
2771   // tensors, and do the broadcast just once
2772   //
2773   //     +                  +
2774   //    / \                / \
2775   //   +   c      -->     +   b
2776   //  / \                / \
2777   // a   b              a   c
2778   EXPECT_EQ(output.node_size(), 9);
2779   NodeMap node_map(&output);
2780 
2781   // expected names of outer and inner nodes
2782   string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
2783   string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";
2784 
2785   // outer Add node
2786   const NodeDef* outer_add = node_map.GetNode(outer_add_name);
2787   ASSERT_NE(outer_add, nullptr);
2788   EXPECT_EQ(outer_add->op(), "AddV2");
2789   ASSERT_EQ(outer_add->input_size(), 2);
2790   EXPECT_EQ(outer_add->input(0), inner_add_name);
2791   EXPECT_EQ(outer_add->input(1), "b");
2792 
2793   // inner AddN node
2794   const NodeDef* inner_add = node_map.GetNode(inner_add_name);
2795   ASSERT_NE(inner_add, nullptr);
2796   ASSERT_EQ(inner_add->input_size(), 2);
2797   EXPECT_EQ(inner_add->input(0), "a");
2798   EXPECT_EQ(inner_add->input(1), "c");
2799 
2800   // check output was re-wired to new node
2801   const NodeDef* updated_outputs = node_map.GetNode("outputs");
2802   ASSERT_NE(updated_outputs, nullptr);
2803   ASSERT_EQ(updated_outputs->input_size(), 1);
2804   EXPECT_EQ(updated_outputs->input(0), outer_add_name);
2805 
2806   auto tensors = EvaluateNodes(output, item.fetch, feed);
2807   ASSERT_EQ(tensors.size(), 1);
2808   test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
2809 }
2810 
// RemoveNegation stage: Add/Sub expressions with Neg operands are rewritten
// in place into equivalent Add/Sub nodes that consume the un-negated inputs
// directly (e.g. Add(-x, y) -> Sub(y, x)), preserving control dependencies.
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
  Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
  // Cover each combination of negated/plain operands for Add and Sub.
  Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
  Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
  Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
  Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
  Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
  Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
  Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
  Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
  // A negation carrying a control dependency: the rewrite must transfer the
  // dependency onto the rewritten consumer (checked below).
  Output neg_x_with_dep = ops::Neg(
      s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
  Output add_negx_with_dep_y =
      ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
  auto add_all =
      ops::AddN(s.WithOpName("add_all"),
                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});

  GrapplerItem item;
  item.fetch = {"add_all"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveNegation(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // The rewrite is in-place: no nodes are added or removed.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "Add_negx_y") {
      // -x + y => y - x
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_x_negy") {
      // x + (-y) => x - y
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Add_negx_negy") {
      // -x + (-y) => (-x) - y; only the second negation is removed.
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Neg_x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_x_negy") {
      // x - (-y) => x + y
      ++found;
      EXPECT_EQ(node.op(), "AddV2");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_negx_negy") {
      // -x - (-y) => y - x
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_negx_with_dep_y") {
      // Rewritten to Sub(y, x); the control dependency of the removed Neg
      // is carried over as a control input.
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
      EXPECT_EQ(node.input(2), "^Add_x_y");
    }
  }
  // All six rewritable nodes above must have been seen.
  EXPECT_EQ(found, 6);

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2898 
// x / sqrt(y) should be rewritten in place as x * rsqrt(y).
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(scope.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::Div(scope.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);

  // Ops are replaced in place, so the node count must be unchanged.
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "output") {
      // Div(x, sqrt_y) became Mul(x, sqrt_y) ...
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      // ... and the Sqrt itself became Rsqrt.
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2935 
// The SqrtDivToRsqrtMul rewrite must be skipped when the Sqrt is itself a
// fetch node: converting it to Rsqrt would change the fetched value.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  // "output0" (the divisor's Sqrt) is fetched, which must block the rewrite.
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "grad") {
      // The Div is left untouched ...
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      // ... and the fetched Sqrt stays a Sqrt.
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
}
2977 
// FloorDiv must be excluded from the SqrtDivToRsqrtMul rewrite: the graph
// is expected to come out structurally unchanged.
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_y = ops::Sqrt(scope.WithOpName("sqrt_y"), y);
  Output div_x_sqrt_y = ops::FloorDiv(scope.WithOpName("output"), x, sqrt_y);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);

  EXPECT_EQ(optimized.node_size(), item.graph.node_size());
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "output") {
      // Still a FloorDiv with its original inputs.
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      // The Sqrt was not converted to Rsqrt.
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
3014 
// FuseSquaredDiff stage: Square(Sub(x, y)) is fused into a single
// SquaredDifference(x, y) node (the old Square becomes an Identity).
// The fusion is not applied when the operands are complex.
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  for (bool is_complex : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output complex_x = ops::Complex(s.WithOpName("complex_x"), x, x);
    Output complex_y = ops::Complex(s.WithOpName("complex_y"), y, y);
    Output sub_x_y =
        is_complex ? ops::Sub(s.WithOpName("sub_x_y"), complex_x, complex_y)
                   : ops::Sub(s.WithOpName("sub_x_y"), x, y);
    Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    const auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);

    if (is_complex) {
      test::ExpectTensorNear<std::complex<float>>(tensors[0],
                                                  tensors_expected[0], 1e-6);
      // No rewrite for complex inputs: the graph keeps all nodes.
      EXPECT_EQ(output.node_size(), item.graph.node_size());
    } else {
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
      // The two unused Complex nodes should get pruned.
      EXPECT_EQ(output.node_size(), item.graph.node_size() - 2);
    }
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      if (node.name() == "output") {
        // In the fused case, the Square becomes a pass-through Identity.
        EXPECT_EQ(node.op(), is_complex ? "Square" : "Identity");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        // In the fused case, the Sub becomes the SquaredDifference.
        EXPECT_EQ(node.op(), is_complex ? "Sub" : "SquaredDifference");
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), is_complex ? "complex_x" : "x");
        EXPECT_EQ(node.input(1), is_complex ? "complex_y" : "y");
      }
    }
  }
}
3064 
// FuseSquaredDiff must be skipped when the Sub is itself a fetch node:
// rewriting it to SquaredDifference would change the fetched value.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sub_x_y = ops::Sub(s.WithOpName("sub_x_y"), x, y);
  Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

  GrapplerItem item;
  // Fetching "sub_x_y" directly must block the fusion.
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "output") {
      // The graph is unchanged: Square(Sub(x, y)) is intact.
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
}
3104 
// Log(Softmax(x)) should be fused into a single LogSoftmax node.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output softmax = ops::Softmax(scope.WithOpName("softmax"), x);
  Output logsoftmax = ops::Log(scope.WithOpName("output"), softmax);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  GraphDef optimized;
  OptimizeAndPrune(&optimizer, &item, &optimized);

  const auto actual = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);

  // The standalone Softmax node is pruned away after the fusion.
  EXPECT_EQ(optimized.node_size(), item.graph.node_size() - 1);
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "output") {
      // The fused node consumes x directly.
      EXPECT_EQ(node.op(), "LogSoftmax");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
3135 
// The LogSoftmax fusion must be skipped when the Softmax output is itself
// fetched: fusing would remove a node whose value the caller requested.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  // "softmax" is fetched, so Log(Softmax) must not be fused.
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  for (int i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3165 
// ConvertPow stage: Pow with a known special constant exponent is rewritten
// into a cheaper op (Square, Sqrt, Rsqrt, Reciprocal, Identity, Const).
// Generic exponents and cases where the rewrite would change the broadcast
// output shape are left as Pow.
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  // Constant exponents covering each special case handled by the rewrite.
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  // Scalar base with non-scalar exponent: replacing the Pow would drop the
  // broadcast to the exponent's shape, so these must stay as Pow.
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  Output out3 =
      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2",   "out3",  "out1", "out.5",      "out0",
                "out_.5", "out_1", "out",  "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 10);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 10);

  for (int i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Expected post-rewrite graph, one node per special-cased exponent.
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x"}, {}, &want);  // x^2 -> Square(x)
  // x^3 -> x * Square(x); the helper Square inherits out3's device.
  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
          &want)
      ->set_device("/device:CPU:0");
  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
          {}, &want)
      ->set_device("/device:CPU:0");
  AddNode("out1", "Identity", {"x"}, {}, &want);  // x^1 -> Identity(x)
  AddNode("out.5", "Sqrt", {"x"}, {}, &want);     // x^0.5 -> Sqrt(x)
  // x^0 -> Const, keeping a control dependency on the old base.
  AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);      // x^-0.5 -> Rsqrt(x)
  AddNode("out_1", "Reciprocal", {"x"}, {}, &want);  // x^-1 -> Reciprocal(x)
  AddNode("out", "Pow", {"x", "y"}, {}, &want);      // generic exponent
  // Broadcasting cases stay as Pow (see z/ones/zeros comment above).
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
3235 
// Log1p stage: Log(1 + x) is rewritten to Log1p(x) when one addend of the
// Add is the constant 1.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});  // all ones
  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  // a12 = 1 + x2, so Log(a12) qualifies for the rewrite; the control
  // dependency on x3 must survive it.
  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  // a23 has no all-ones addend, so Log(a23) must stay a plain Log.
  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(s.WithOpName("out1"), a12);
  Output out2 = ops::Log(s.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  // out1 consumes x2 directly; x1 and a12 have been pruned, and the x3
  // control dependency is carried over to the Log1p node.
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
3274 
// Expm1 stage: Exp(x) - 1 is rewritten to Expm1(x) when the subtrahend of
// the Sub is the constant 1.
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});  // all ones
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  // exp1 carries a control dependency on x3 that must survive the rewrite.
  auto exp1 = ops::Exp(s.WithOpName("exp1").WithControlDependencies(x3), x1);
  // out1 = exp(x1) - 1 qualifies; out2 subtracts x3 (not ones) and stays.
  Output out1 = ops::Sub(s.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(s.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  // exp1 is kept because out2 still consumes it.
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  // out1 consumes x1 directly and inherits the x3 control dependency.
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
3312 
// MinimizeBroadcasts: in (a * b) * c with vector a, c and matrix b, the two
// vector operands should be paired first so only the root Mul broadcasts.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);
  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto a_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_value}, {"b", b_value}, {"c", c_value}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite:
  //
  //     *                  *
  //    / \                / \
  //   *   c      -->     *   b
  //  / \                / \
  // a   b              a   c
  NodeMap node_map(&output);

  // The inner Mul now multiplies the two vectors ...
  const NodeDef* inner_mul = node_map.GetNode("mul1");
  ASSERT_NE(inner_mul, nullptr);
  ASSERT_EQ(inner_mul->input_size(), 2);
  EXPECT_EQ(inner_mul->input(0), "a");
  EXPECT_EQ(inner_mul->input(1), "c");

  // ... and the matrix joins at the root.
  const NodeDef* root_mul = node_map.GetNode("mul2");
  ASSERT_NE(root_mul, nullptr);
  ASSERT_EQ(root_mul->input_size(), 2);
  EXPECT_EQ(root_mul->input(0), "mul1");
  EXPECT_EQ(root_mul->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3368 
// MinimizeBroadcasts: a tall left-leaning chain of Muls mixing vectors and
// one matrix is rebalanced so the vector operands combine first and the
// matrix joins at the root (see ASCII diagram below).
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *      b
  //    / \                 /  \
  //   *   c      -->     *      *
  //  / \                / \    / \
  // a   b              a   c  d   e
  NodeMap node_map(&output);

  // Note: node names survive the rewrite but hold new subexpressions; the
  // assertions below pin each node's rewritten inputs.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "d");
  EXPECT_EQ(mul2_node->input(1), "e");

  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "mul1");
  EXPECT_EQ(mul3_node->input(1), "mul2");

  const NodeDef* mul4_node = node_map.GetNode("mul4");
  ASSERT_NE(mul4_node, nullptr);
  ASSERT_EQ(mul4_node->input_size(), 2);
  EXPECT_EQ(mul4_node->input(0), "mul3");
  EXPECT_EQ(mul4_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
3447 
TEST_F(ArithmeticOptimizerTest,MinimizeBroadcasts_BuildTreeUp)3448 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
3449   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3450 
3451   // [a, b, c] - scalars, [d] - matrix
3452   auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
3453   auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
3454   auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
3455   auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
3456 
3457   auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
3458   auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
3459   auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
3460 
3461   auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
3462 
3463   GrapplerItem item;
3464   item.fetch = {"outputs"};
3465   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3466 
3467   auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3468   auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3469   auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
3470   auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
3471   std::vector<std::pair<string, Tensor>> feed = {
3472       {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
3473   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
3474   ASSERT_EQ(tensors_expected.size(), 1);
3475 
3476   GraphDef output;
3477   ArithmeticOptimizer optimizer;
3478   EnableOnlyMinimizeBroadcasts(&optimizer);
3479 
3480   OptimizeAndPrune(&optimizer, &item, &output);
3481 
3482   // We expect the following rewrite(s) to occur:
3483   //
3484   //                              *
3485   //                            /  \
3486   //       *                   *    D
3487   //     /   \                / \
3488   //    *     *      ->      *   c
3489   //   / \   / \            / \
3490   //  a   b c   D          a   b
3491   NodeMap node_map(&output);
3492 
3493   const NodeDef* mul1_node = node_map.GetNode("mul2");
3494   ASSERT_NE(mul1_node, nullptr);
3495   ASSERT_EQ(mul1_node->input_size(), 2);
3496   EXPECT_EQ(mul1_node->input(0), "a");
3497   EXPECT_EQ(mul1_node->input(1), "b");
3498 
3499   const NodeDef* mul2_node = node_map.GetNode("mul1");
3500   ASSERT_NE(mul2_node, nullptr);
3501   ASSERT_EQ(mul2_node->input_size(), 2);
3502   EXPECT_EQ(mul2_node->input(0), "mul2");
3503   EXPECT_EQ(mul2_node->input(1), "c");
3504 
3505   const NodeDef* mul3_node = node_map.GetNode("mul3");
3506   ASSERT_NE(mul3_node, nullptr);
3507   ASSERT_EQ(mul3_node->input_size(), 2);
3508   EXPECT_EQ(mul3_node->input(0), "D");
3509   EXPECT_EQ(mul3_node->input(1), "mul1");
3510 
3511   auto tensors = EvaluateNodes(output, item.fetch, feed);
3512   ASSERT_EQ(tensors.size(), 1);
3513   test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
3514 }
3515 
// Verifies that the per-branch Relu activations feeding a Concat are not
// hoisted above it by the full (default-stage) ArithmeticOptimizer.
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // Two Conv2D -> BiasAdd -> Relu branches concatenated on the channel axis.
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Guard the indexed accesses below: one tensor per fetched node.
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  // size_t index avoids a signed/unsigned comparison with vector::size().
  for (size_t i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3556 
// Verifies that a unary op applied identically to every data input of a
// Concat is hoisted above the Concat, for chains of length 1 (Exp) and
// length 2 (Cos(Exp)). Control dependencies are attached to several of the
// unary nodes to exercise their preservation through the rewrite.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Constants used purely as control-dependency anchors.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //       Concat({Exp(a), Exp(b), Exp(c)})
  // into
  //       Exp(Concat({a, b, c})).
  // Note: only Exp is common to all three Concat inputs; sin_a feeds exp_a
  // and is expected to remain below the hoisted Exp.
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //       Concat({Cos(Exp(a)), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //       Cos(Exp(Concat({a, b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // "concat" should now read the former Exp inputs (sin_a, b, c) directly.
    if (node.name() == "concat") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    // "exp_a" is now above the concat: it consumes the concat's output.
    if (node.name() == "exp_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat");
      found++;
    }
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a");
      found++;
    }

    // Same checks for the length-2 chain: concat2 reads raw inputs, and the
    // fetched value flows through the hoisted exp_a2 -> cos_exp_a2 chain.
    if (node.name() == "concat2") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat2");
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a2");
      found++;
    }
    if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "cos_exp_a2");
      found++;
    }
  }
  // All seven rewired nodes must have been seen exactly once.
  EXPECT_EQ(found, 7);

  // The rewrite must not change the computed values.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3665 
// Verifies the dual rewrite of HoistCWiseUnaryFromConcat: a unary chain
// applied identically to every output of a Split/SplitV is pushed above the
// split, so it runs once on the un-split input. The hoisted ops get new
// "ArithmeticOptimizer/_<op>_<split>" names and the originals are pruned.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // Constants used purely as control-dependency anchors.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //          [Sin(y) for y in Split(x)]
  // into
  //          [y for y in Split(Sin(x))].
  // Note: Sin is common to both split outputs; exp_b is not, so it must
  // remain below the split.
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //          [Cos(Exp(y)) for y in Split(x)]
  // into
  //          [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    // split1's data input is now the hoisted Sin of x.
    if (node.name() == "split1") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "axis");
      EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
      found++;
    }
    // The hoisted Sin runs on the full tensor x, above the split.
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ(node.op(), "Sin");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "id_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1");
      found++;
    }
    // exp_b was not hoisted; it now reads the split's second output directly.
    if (node.name() == "exp_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1:1");
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_b");
      found++;
    }
    // The length-2 chain is hoisted as Exp then Cos above split2.
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Exp");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Cos");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
      found++;
    }
    if (node.name() == "split2") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
      EXPECT_EQ(node.input(1), "size_splits2");
      EXPECT_EQ(node.input(2), "axis");
      found++;
    }
    // The identities now consume the split outputs directly.
    if (node.name() == "id_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2");
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2:1");
      found++;
    }
  }
  // All ten expected nodes must have been seen exactly once.
  EXPECT_EQ(found, 10);

  // The rewrite must not change the computed values.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3790 
// Snapshot and Identity are idempotent (f(f(x)) == f(x)); the optimizer
// should rewire consumers of back-to-back pairs to skip the redundant op.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  // Chain 1: Snapshot(Snapshot(a)) -> Identity.
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  // Chain 2: Identity(Identity(a)) -> Identity.
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // No pruning pass is run, so all 7 original nodes are still present; only
  // the inputs of the consumers are rewired.
  // Fixed to the (actual, expected) argument order used throughout this file.
  EXPECT_EQ(output.node_size(), 7);
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      ASSERT_EQ(node.input_size(), 1);
      // out1 now bypasses the redundant second Snapshot.
      EXPECT_EQ(node.input(0), "sn1");
      found++;
    } else if (node.name() == "out2") {
      ASSERT_EQ(node.input_size(), 1);
      // out2 now bypasses the redundant second Identity.
      EXPECT_EQ(node.input(0), "id1");
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
      found++;
    }
  }
  EXPECT_EQ(found, 3);

  // The rewrite must not change the computed values.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3837 
TEST_F(ArithmeticOptimizerTest,RemoveLogicalNot)3838 TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
3839   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3840   Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
3841   Output b = ops::Const(s.WithOpName("b"), -3.14f, {32});
3842   Output eq = ops::Equal(s.WithOpName("eq"), a, b);
3843   Output neq = ops::NotEqual(s.WithOpName("neq"), a, b);
3844   Output lt = ops::Less(s.WithOpName("lt"), a, b);
3845   Output le = ops::LessEqual(s.WithOpName("le"), a, b);
3846   Output gt = ops::Greater(s.WithOpName("gt"), a, b);
3847   Output ge = ops::GreaterEqual(s.WithOpName("ge"), a, b);
3848   // not_eq is reserved
3849   Output not_eq1 = ops::LogicalNot(s.WithOpName("not_eq1"), eq);
3850   Output not_neq = ops::LogicalNot(s.WithOpName("not_neq"), neq);
3851   Output not_lt = ops::LogicalNot(s.WithOpName("not_lt"), lt);
3852   Output not_le = ops::LogicalNot(s.WithOpName("not_le"), le);
3853   Output not_gt = ops::LogicalNot(s.WithOpName("not_gt"), gt);
3854   Output not_ge = ops::LogicalNot(s.WithOpName("not_ge"), ge);
3855   Output id_not_eq = ops::Identity(s.WithOpName("id_not_eq"), not_eq1);
3856   Output id_not_neq = ops::Identity(s.WithOpName("id_not_neq"), not_neq);
3857   Output id_not_lt = ops::Identity(s.WithOpName("id_not_lt"), not_lt);
3858   Output id_not_le = ops::Identity(s.WithOpName("id_not_le"), not_le);
3859   Output id_not_gt = ops::Identity(s.WithOpName("id_not_gt"), not_gt);
3860   Output id_not_ge = ops::Identity(s.WithOpName("id_not_ge"), not_ge);
3861 
3862   GrapplerItem item;
3863   item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
3864                 "id_not_le", "id_not_gt",  "id_not_ge"};
3865   TF_CHECK_OK(s.ToGraphDef(&item.graph));
3866 
3867   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3868 
3869   GraphDef output;
3870   ArithmeticOptimizer optimizer;
3871   EnableOnlyRemoveLogicalNot(&optimizer);
3872   OptimizeTwice(&optimizer, &item, &output);
3873 
3874   int found = 0;
3875   for (const NodeDef& node : output.node()) {
3876     if (node.name() == "id_not_eq") {
3877       ASSERT_EQ(node.input_size(), 1);
3878       EXPECT_EQ(node.input(0), "eq");
3879       ++found;
3880     }
3881     if (node.name() == "id_not_neq") {
3882       ASSERT_EQ(node.input_size(), 1);
3883       EXPECT_EQ(node.input(0), "neq");
3884       ++found;
3885     }
3886     if (node.name() == "id_not_lt") {
3887       ASSERT_EQ(node.input_size(), 1);
3888       EXPECT_EQ(node.input(0), "lt");
3889       ++found;
3890     }
3891     if (node.name() == "id_not_le") {
3892       ASSERT_EQ(node.input_size(), 1);
3893       EXPECT_EQ(node.input(0), "le");
3894       ++found;
3895     }
3896     if (node.name() == "id_not_gt") {
3897       ASSERT_EQ(node.input_size(), 1);
3898       EXPECT_EQ(node.input(0), "gt");
3899       ++found;
3900     }
3901     if (node.name() == "id_not_ge") {
3902       ASSERT_EQ(node.input_size(), 1);
3903       EXPECT_EQ(node.input(0), "ge");
3904       ++found;
3905     }
3906 
3907     if (node.name() == "eq") {
3908       EXPECT_EQ(node.op(), "NotEqual");
3909       ++found;
3910     }
3911     if (node.name() == "neq") {
3912       EXPECT_EQ(node.op(), "Equal");
3913       ++found;
3914     }
3915     if (node.name() == "lt") {
3916       EXPECT_EQ(node.op(), "GreaterEqual");
3917       ++found;
3918     }
3919     if (node.name() == "le") {
3920       EXPECT_EQ(node.op(), "Greater");
3921       ++found;
3922     }
3923     if (node.name() == "gt") {
3924       EXPECT_EQ(node.op(), "LessEqual");
3925       ++found;
3926     }
3927     if (node.name() == "ge") {
3928       EXPECT_EQ(node.op(), "Less");
3929       ++found;
3930     }
3931   }
3932   EXPECT_EQ(found, 12);
3933 
3934   auto tensors = EvaluateNodes(output, item.fetch);
3935   ASSERT_EQ(tensors.size(), tensors_expected.size());
3936   EXPECT_EQ(tensors.size(), item.fetch.size());
3937   for (int i = 0; i < item.fetch.size(); ++i) {
3938     test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
3939   }
3940 }
3941 
// Sqrt is monotonically increasing, so Max(Sqrt(x)) can be computed as
// Sqrt(Max(x)): the optimizer swaps the two nodes' positions in the graph.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt_x, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // Same value, same node count — only the wiring changes.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Verify the two nodes swapped inputs: Sqrt now reads the reduction, and
  // the reduction reads x directly.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++matched;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
3982 
// ArgMax(Sqrt(x)) == ArgMax(x) for an increasing function, so the Sqrt node
// can be removed entirely and ArgMax rewired to read x.
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(scope.WithOpName("arg_max"), sqrt_x, 1);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<int64_t>(tensors[0], tensors_expected[0]);
  // One node fewer: the Sqrt was dropped rather than swapped.
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++matched;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      // ArgMax now reads "x" directly instead of "sqrt".
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4023 
// "sqrt" is itself a fetch node here, so its observable output may not be
// altered and the rewrite must be skipped.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt_x, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  // Fetching "sqrt" directly pins it in place.
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a no-op: the optimizer may not change the output of fetch nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
4047 
// The reduction node "z" is the fetch node, so swapping it with the Neg
// would change its observable value (Max -> Min); the rewrite must not fire.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {2, 3}, {1, 2});
  Output flattened = ops::Reshape(scope.WithOpName("reshape"), x, {-1});
  Output negated = ops::Neg(scope.WithOpName("y"), flattened);
  Output reduced = ops::Max(scope.WithOpName("z"), negated, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a no-op: the optimizer may not change the output of fetch nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int>(tensors[0], tensors_expected[0]);
  // max(-2, -3) == -2.
  test::ExpectTensorEqual<int>(tensors[0], Tensor(-2));
}
4076 
// The monotonic-op rewrite must not be applied across segment max/min
// reductions (sorted or unsorted): for each such op the optimized graph is
// expected to be identical to the input graph.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeSegmentMaxOrMinOps) {
  constexpr absl::string_view kSegmentMaxOpName = "SegmentMax";
  constexpr absl::string_view kUnsortedSegmentMaxOpName = "UnsortedSegmentMax";
  constexpr absl::string_view kSegmentMinOpName = "SegmentMin";
  constexpr absl::string_view kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
  constexpr absl::string_view segment_max_or_min_op_names[] = {
      kSegmentMaxOpName, kUnsortedSegmentMaxOpName, kSegmentMinOpName,
      kUnsortedSegmentMinOpName};
  // Build one small Relu -> segment-reduction graph per op under test.
  for (const absl::string_view segment_op_name : segment_max_or_min_op_names) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Input x = ops::Const(s.WithOpName("x"), {-1.0f, 2.0f, -3.0f, 4.0f}, {2, 2});
    // Bug fix: this Const was previously also named "x", colliding with the
    // input node above (the scope silently uniquified the duplicate). Name
    // it after what it actually holds.
    Input segment_ids = ops::Const(s.WithOpName("segment_ids"), {0, 2}, {2});
    Output relu = ops::Relu(s.WithOpName("relu"), x);
    Output segment_op;
    if (segment_op_name == kSegmentMaxOpName) {
      segment_op =
          ops::SegmentMax(s.WithOpName(segment_op_name), relu, segment_ids);
    } else if (segment_op_name == kUnsortedSegmentMaxOpName) {
      segment_op = ops::UnsortedSegmentMax(s.WithOpName(segment_op_name), relu,
                                           segment_ids, 3);
    } else if (segment_op_name == kSegmentMinOpName) {
      segment_op =
          ops::SegmentMin(s.WithOpName(segment_op_name), relu, segment_ids);
    } else {
      segment_op = ops::UnsortedSegmentMin(s.WithOpName(segment_op_name), relu,
                                           segment_ids, 3);
    }
    Output final_out = ops::Identity(s.WithOpName("final_out"), segment_op);

    GrapplerItem item;
    item.fetch = {"relu", "final_out"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    EXPECT_EQ(tensors_expected.size(), 2);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
    OptimizeTwice(&optimizer, &item, &output);

    // Expect a no-op: Relu must not be swapped with the segment reduction.
    VerifyGraphsMatch(item.graph, output, __LINE__);
  }
}
4121 
// Neg is monotonic but non-increasing, so hoisting it past the reduction
// flips the reduction op: Max(Neg(x)) is rewritten to Neg(Min(x)).
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // Same value, same node count — only the wiring and the reduction op
  // change.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Check if the inputs are switched
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      // The reduction flipped from Max to Min because Neg is non-increasing.
      EXPECT_EQ(node.op(), "Min");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // Fixed to the (actual, expected) argument order used throughout this
  // file.
  EXPECT_EQ(required_node_count, 2);
}
4163 
// A non-increasing op (Neg) feeding a MaxPool must not be rewritten: the
// optimized graph is expected to be identical to the input graph.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output negated = ops::Neg(scope.WithOpName("neg"), x);
  Output pooled = ops::MaxPool(scope.WithOpName("max_pool"), negated,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Expect a no-op.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4190 
// A realistic Conv2D -> BiasAdd -> Relu -> MaxPool chain must pass through
// the monotonic-op optimizer untouched.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output filter = ops::Const(scope.WithOpName("weights"),
                             Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output bias =
      ops::Const(scope.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output image = ops::Const(scope.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  Output net = ops::Conv2D(scope.WithOpName("conv"), image, filter,
                           {1, 1, 1, 1}, "SAME");
  net = ops::BiasAdd(scope.WithOpName("biasadd"), net, bias);
  net = ops::Relu(scope.WithOpName("relu"), net);
  net = ops::MaxPool(scope.WithOpName("max_pool"), net, {1, 2, 2, 1},
                     {1, 2, 2, 1}, "VALID");
  net = ops::Identity(scope.WithOpName("output"), net);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &new_graph);

  // Expect a no-op.
  VerifyGraphsMatch(item.graph, new_graph, __LINE__);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4225 
// Sqrt is monotonically increasing, so MaxPool(Sqrt(x)) can be rewritten as
// Sqrt(MaxPool(x)), computing the expensive unary op on the smaller tensor.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt_x = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output pooled = ops::MaxPool(scope.WithOpName("max_pool"), sqrt_x,
                               {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output final_out = ops::Identity(scope.WithOpName("final_out"), pooled);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  // Same value, same node count — only the wiring changes.
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Verify Sqrt and MaxPool swapped positions.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++matched;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);
}
4267 
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  // A Sqrt -> Log -> Relu chain on CPU should be fused into a single
  // _UnaryOpsComposition node whose "op_names" attr lists the ops in order.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
  Output log = ops::Log(s.WithOpName("log"), sqrt);
  Output relu = ops::Relu(s.WithOpName("relu"), log);
  Output final_out = ops::Identity(s.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // The composition kernel is registered on CPU only; pin all nodes there.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 3);

  // Sqrt/Log/Relu must have been replaced with one composed op.
  int matched = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++matched;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      auto op_names = node.attr().at("op_names").list().s();
      ASSERT_EQ(op_names.size(), 3);
      EXPECT_EQ(op_names[0], "Sqrt");
      EXPECT_EQ(op_names[1], "Log");
      EXPECT_EQ(op_names[2], "Relu");
      ++matched;
    }
  }
  EXPECT_EQ(matched, 2);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4324 
TEST_F(ArithmeticOptimizerTest,RemoveStackStridedSliceSameAxis)4325 TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
4326   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
4327   auto a_in =
4328       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
4329   auto b_in =
4330       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
4331   auto c_in =
4332       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
4333   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
4334                                        PartialTensorShape({-1, -1}));
4335   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
4336                                        PartialTensorShape({-1, -1}));
4337   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
4338                                        PartialTensorShape({-1, -1}));
4339   // stacked = tf.stack((a, b, c), axis=1).
4340   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
4341   auto stacked =
4342       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
4343                  ops::Stack::Axis(1));
4344   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
4345   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
4346   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
4347   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
4348   auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
4349   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
4350   auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
4351   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
4352   auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
4353   auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
4354   auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
4355 
4356   // stacked[:, 0]
4357   using SS = ops::StridedSlice;
4358   auto pa_slice = ops::Identity(
4359       s.WithOpName("pa_slice_out"),
4360       SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
4361          SS::BeginMask(0b0101)  // 5
4362              .EllipsisMask(0)
4363              .EndMask(0b0101)  // 5
4364              .NewAxisMask(0)
4365              .ShrinkAxisMask(0b0010)));  // 2
4366 
4367   // stacked[:, 1]
4368   auto pb_slice = ops::Identity(
4369       s.WithOpName("pb_slice_out"),
4370       SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
4371          SS::BeginMask(0b0101)  // 5
4372              .EllipsisMask(0)
4373              .EndMask(0b0101)  // 5
4374              .NewAxisMask(0)
4375              .ShrinkAxisMask(0b0010)));  // 2
4376 
4377   // stacked[:, 2]
4378   auto pc_slice = ops::Identity(
4379       s.WithOpName("pc_slice_out"),
4380       SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
4381          SS::BeginMask(0b0101)  // 5
4382              .EllipsisMask(0)
4383              .EndMask(0b0101)  // 5
4384              .NewAxisMask(0)
4385              .ShrinkAxisMask(0b0010)));  // 2
4386 
4387   // stacked[:, 0:1, :]
4388   auto pa_slice_01 = ops::Identity(
4389       s.WithOpName("pa_slice_01_out"),
4390       SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
4391          SS::BeginMask(0b0101)  // 5
4392              .EllipsisMask(0)
4393              .EndMask(0b0101)  // 5
4394              .NewAxisMask(0)
4395              .ShrinkAxisMask(0)));
4396 
4397   // stacked[:, :1, :]
4398   auto pa_slice_to1 = ops::Identity(
4399       s.WithOpName("pa_slice_to1_out"),
4400       SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
4401          SS::BeginMask(0b0111)  // 7
4402              .EllipsisMask(0)
4403              .EndMask(0b0101)  // 5
4404              .NewAxisMask(0)
4405              .ShrinkAxisMask(0)));
4406 
4407   // stacked[:, 1:2, :]
4408   auto pb_slice_12 = ops::Identity(
4409       s.WithOpName("pb_slice_12_out"),
4410       SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
4411          SS::BeginMask(0b0101)  // 5
4412              .EllipsisMask(0)
4413              .EndMask(0b0101)  // 5
4414              .NewAxisMask(0)
4415              .ShrinkAxisMask(0)));
4416 
4417   // stacked[:, 2:, :].
4418   auto pc_slice_2to = ops::Identity(
4419       s.WithOpName("pc_slice_2to_out"),
4420       SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
4421          SS::BeginMask(0b0101)  // 5
4422              .EllipsisMask(0)
4423              .EndMask(0b0111)  // 7
4424              .NewAxisMask(0)
4425              .ShrinkAxisMask(0)));
4426 
4427   GrapplerItem item;
4428   item.fetch = {"a",
4429                 "b",
4430                 "c",
4431                 "pa_slice_out",
4432                 "pb_slice_out",
4433                 "pc_slice_out",
4434                 "expanded_a",
4435                 "expanded_b",
4436                 "expanded_c",
4437                 "pa_slice_01_out",
4438                 "pa_slice_to1_out",
4439                 "pb_slice_12_out",
4440                 "pc_slice_2to_out"};
4441   enum FetchItem {
4442     fA,
4443     fB,
4444     fC,
4445     fASliceOut,
4446     fBSliceOut,
4447     fCSliceOut,
4448     fExpandedA,
4449     fExpandedB,
4450     fExpandedC,
4451     fASlice01Out,
4452     fASliceTo1Out,
4453     fBSlice12Out,
4454     fCSlice2ToOut,
4455   };
4456   TF_CHECK_OK(s.ToGraphDef(&item.graph));
4457   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
4458 
4459   // stacked[:, 0, :] == a.
4460   test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
4461                                  tensors_expected[fA]);
4462   // stacked[:, 1, :] == b.
4463   test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
4464                                  tensors_expected[fB]);
4465   // stacked[:, 2, :] == c.
4466   test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
4467                                  tensors_expected[fC]);
4468 
4469   // stacked[:, 0:1, :] == expand_dims(a, 1).
4470   test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
4471                                  tensors_expected[fExpandedA]);
4472 
4473   // stacked[:, :1, :] == expand_dims(a, 1).
4474   test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
4475                                  tensors_expected[fExpandedA]);
4476 
4477   // stacked[:, 1:2, :] == expand_dims(b, 1).
4478   test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
4479                                  tensors_expected[fExpandedB]);
4480   // stacked[:, 2:, :] == expand_dims(c, 1).
4481   test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
4482                                  tensors_expected[fExpandedC]);
4483 
4484   GraphDef output;
4485   ArithmeticOptimizer optimizer;
4486   EnableOnlyRemoveStackSliceSameAxis(&optimizer);
4487   OptimizeAndPrune(&optimizer, &item, &output);
4488 
4489   for (const auto& node : output.node()) {
4490     if (node.name() == "pa_slice_out") {
4491       ASSERT_EQ(node.input_size(), 1);
4492       EXPECT_EQ(node.input(0), "a");
4493     } else if (node.name() == "pb_slice_out") {
4494       ASSERT_EQ(node.input_size(), 1);
4495       EXPECT_EQ(node.input(0), "b");
4496     } else if (node.name() == "pc_slice_out") {
4497       ASSERT_EQ(node.input_size(), 1);
4498       EXPECT_EQ(node.input(0), "c");
4499     } else if (str_util::EndsWith(node.name(), "_out")) {
4500       ASSERT_EQ(node.input_size(), 1);
4501       EXPECT_EQ(
4502           absl::StrCat(node.input(0), "_out"),
4503           absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
4504                        node.name()));
4505     }
4506   }
4507 
4508   auto tensors = EvaluateNodes(output, item.fetch);
4509 
4510   // stacked[:, 0, :] == a.
4511   test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);
4512 
4513   // stacked[:, 1, :] == b.
4514   test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
4515   // stacked[:, 2, :] == c.
4516   test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);
4517 
4518   // stacked[:, 0:1, :] == expand_dims(a, 1).
4519   test::ExpectTensorEqual<float>(tensors[fASlice01Out],
4520                                  tensors_expected[fExpandedA]);
4521 
4522   // stacked[:, :1, :] == expand_dims(a, 1).
4523   test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
4524                                  tensors_expected[fExpandedA]);
4525 
4526   // stacked[:, 1:2, :] == expand_dims(b, 1).
4527   test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
4528                                  tensors_expected[fExpandedB]);
4529   // stacked[:, 2:, :] == expand_dims(c, 1).
4530   test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
4531                                  tensors_expected[fExpandedC]);
4532 }
4533 
TEST_F(ArithmeticOptimizerTest,RemoveStackSimpleSliceSameAxis)4534 TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
4535   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
4536   auto a_in =
4537       ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
4538   auto b_in =
4539       ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
4540   auto c_in =
4541       ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
4542   auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
4543                                        PartialTensorShape({-1, -1}));
4544   auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
4545                                        PartialTensorShape({-1, -1}));
4546   auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
4547                                        PartialTensorShape({-1, -1}));
4548   // stacked = tf.stack((a, b, c), axis=1).
4549   // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
4550   auto stacked =
4551       ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
4552                  ops::Stack::Axis(1));
4553   auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
4554   auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
4555   auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
4556   auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
4557   auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
4558   auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
4559   auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});
4560 
4561   // stacked[:, 0:1, :]
4562   auto pa_slice = ops::Identity(
4563       s.WithOpName("pa_slice_out"),
4564       ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));
4565 
4566   // stacked[:, 1:2, :]
4567   auto pb_slice = ops::Identity(
4568       s.WithOpName("pb_slice_out"),
4569       ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));
4570 
4571   // stacked[:, 2:3, :]
4572   auto pc_slice = ops::Identity(
4573       s.WithOpName("pc_slice_out"),
4574       ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));
4575 
4576   GrapplerItem item;
4577   item.fetch = {"a",
4578                 "b",
4579                 "c",
4580                 "pa_slice_out",
4581                 "pb_slice_out",
4582                 "pc_slice_out",
4583                 "expanded_a",
4584                 "expanded_b",
4585                 "expanded_c"};
4586   enum FetchItem {
4587     fA,
4588     fB,
4589     fC,
4590     fASliceOut,
4591     fBSliceOut,
4592     fCSliceOut,
4593     fExpandedA,
4594     fExpandedB,
4595     fExpandedC,
4596   };
4597   TF_CHECK_OK(s.ToGraphDef(&item.graph));
4598   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
4599 
4600   // stacked[:, 0:1, :] == a.
4601   test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
4602                                  tensors_expected[fExpandedA]);
4603   // stacked[:, 1:2, :] == b.
4604   test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
4605                                  tensors_expected[fExpandedB]);
4606   // stacked[:, 2:3, :] == c.
4607   test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
4608                                  tensors_expected[fExpandedC]);
4609 
4610   GraphDef output;
4611   ArithmeticOptimizer optimizer;
4612   EnableOnlyRemoveStackSliceSameAxis(&optimizer);
4613   OptimizeAndPrune(&optimizer, &item, &output);
4614 
4615   const string kExpandDimsNamePrefix(
4616       "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");
4617 
4618   for (const auto& node : output.node()) {
4619     if (node.name() == "pa_slice_out") {
4620       ASSERT_EQ(node.input_size(), 1);
4621       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
4622     } else if (node.name() == "pb_slice_out") {
4623       ASSERT_EQ(node.input_size(), 1);
4624       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
4625     } else if (node.name() == "pc_slice_out") {
4626       ASSERT_EQ(node.input_size(), 1);
4627       EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
4628     } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
4629       EXPECT_EQ(node.op(), "ExpandDims");
4630       // The input is "a", "b", or "c", as appropriate.
4631       EXPECT_EQ(node.input(0),
4632                 node.name().substr(kExpandDimsNamePrefix.size(), 1));
4633     }
4634   }
4635 
4636   auto tensors = EvaluateNodes(output, item.fetch);
4637 
4638   // stacked[:, 0:1, :] == a.
4639   test::ExpectTensorEqual<float>(tensors[fASliceOut],
4640                                  tensors_expected[fExpandedA]);
4641 
4642   // stacked[:, 1:2, :] == b.
4643   test::ExpectTensorEqual<float>(tensors[fBSliceOut],
4644                                  tensors_expected[fExpandedB]);
4645   // stacked[:, 2:3, :] == c.
4646   test::ExpectTensorEqual<float>(tensors[fCSliceOut],
4647                                  tensors_expected[fExpandedC]);
4648 }
4649 
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  // AddN over n identical inputs should be rewritten as a multiply by n;
  // verify the rewrite also applies to bfloat16 tensors.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output cast = ops::Cast(scope.WithOpName("cast"), x, DT_BFLOAT16);
  Output add = ops::AddN(scope.WithOpName("add"), {cast, cast});
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // One extra node holds the constant multiplier.
  EXPECT_EQ(output.node_size(), 5);

  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<bfloat16>(actual[0], expected[0]);
}
4675 
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  // Gather(embeddings, unique.y) followed by SparseSegmentSum over
  // unique.idx is equivalent to summing directly over the raw indices, so
  // the Unique and Gather nodes should be optimized away.
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(s.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    Output gathered_rows =
        ops::Gather(s.WithOpName("gathered_rows"), embeddings, unique.y);
    Output result = ops::SparseSegmentSum(s.WithOpName("result"), gathered_rows,
                                          unique.idx, segment_ids);
    Output id = ops::Identity(s.WithOpName("id"), result);

    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);

    // The reduction now reads the table and raw indices directly, and no
    // Unique or Gather node may survive.
    for (const auto& node : output.node()) {
      if (node.name() == "result") {
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }

    auto actual = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(actual.size(), 1);
    test::ExpectTensorEqual<float>(actual[0], expected[0]);
  }
}
4719 
TEST_F(ArithmeticOptimizerTest,SimplifyResourceEmbeddingLookup)4720 TEST_F(ArithmeticOptimizerTest, SimplifyResourceEmbeddingLookup) {
4721   for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
4722     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
4723     Output embeddings = ops::Const(s.WithOpName("embeddings"),
4724                                    {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
4725     Output segment_ids =
4726         ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
4727     Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
4728     auto unique = ops::Unique(s.WithOpName("unique"), indices,
4729                               /*attrs=*/{unique_idx_type});
4730     Output ids = unique.y;
4731     Output idx = unique.idx;
4732 
4733     auto var =
4734         ops::VarHandleOp(s.WithOpName("var"), DT_FLOAT, TensorShape({2, 2}));
4735     ops::AssignVariableOp assign_op(s.WithOpName("assign_var_handle"), var,
4736                                     embeddings);
4737 
4738     Output gathered_rows = ops::ResourceGather(
4739         s.WithOpName("gathered_rows")
4740             .WithControlDependencies(std::vector<Operation>{assign_op}),
4741         var, ids, DT_FLOAT);
4742     gathered_rows.node()->AddAttr("_class", {"test_class"});
4743     Output result =
4744         ops::SparseSegmentSum(s.WithOpName("result").WithControlDependencies(
4745                                   std::vector<Operation>{assign_op}),
4746                               gathered_rows, idx, segment_ids);
4747     Output id = ops::Identity(s.WithOpName("id"), result);
4748 
4749     GrapplerItem item;
4750     item.init_ops.push_back("assign_var_handle");
4751     TF_CHECK_OK(s.ToGraphDef(&item.graph));
4752     item.fetch = {"id"};
4753     auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
4754     ASSERT_EQ(tensors_expected.size(), 1);
4755 
4756     GraphDef output;
4757     ArithmeticOptimizer optimizer;
4758     EnableOnlySimplifyEmbeddingLookup(&optimizer);
4759     OptimizeAndPrune(&optimizer, &item, &output);
4760     bool read_var_node_found = false;
4761     for (const auto& node : output.node()) {
4762       if (node.name() == "result") {
4763         EXPECT_EQ(
4764             node.input(0),
4765             "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result");
4766         EXPECT_EQ(node.input(1), "indices");
4767       }
4768       if (node.op() == "ReadVariableOp") {
4769         read_var_node_found = true;
4770         EXPECT_EQ(node.attr().at("_class").list().s(0), "test_class");
4771       }
4772       EXPECT_NE(node.op(), "Unique");
4773       EXPECT_NE(node.op(), "Gather");
4774     }
4775     EXPECT_TRUE(read_var_node_found);
4776     // Add a control dependency to the ReadVar to do the AssignVar first. This
4777     // shouldn't be an issue in actual use as variables are assumed initialized
4778     // during setup.
4779     for (int i = 0; i < output.node_size(); ++i) {
4780       if (output.node(i).name() ==
4781           "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result") {
4782         output.mutable_node(i)->add_input("^assign_var_handle");
4783       }
4784     }
4785 
4786     auto tensors = EvaluateNodes(output, item.fetch);
4787     ASSERT_EQ(tensors.size(), 1);
4788     test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
4789   }
4790 }
4791 
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  // SparseSegmentSum accepts both int32 and int64 indices/segment ids, so
  // Cast nodes feeding those inputs are redundant and should be removed.
  for (DataType indices_type : {DT_INT32, DT_INT64}) {
    for (DataType segment_ids_type : {DT_INT32, DT_INT64}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                     {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      Output indices =
          ops::Cast(s.WithOpName("cast_indices"),
                    ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
                    indices_type);
      Output segment_ids = ops::Cast(
          s.WithOpName("cast_segment_ids"),
          ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          segment_ids_type);
      Output result = ops::SparseSegmentSum(s.WithOpName("result"), embeddings,
                                            indices, segment_ids);
      Output id = ops::Identity(s.WithOpName("id"), result);

      GrapplerItem item;
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      item.fetch = {"id"};
      auto expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(expected.size(), 1);

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      OptimizeAndPrune(&optimizer, &item, &output);

      // The reduction now consumes the uncast tensors, and every Cast node
      // must be gone.
      for (const auto& node : output.node()) {
        if (node.name() == "result") {
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        EXPECT_NE(node.op(), "Cast");
      }

      auto actual = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(actual.size(), 1);
      test::ExpectTensorEqual<float>(actual[0], expected[0]);
    }
  }
}
4835 
4836 }  // namespace grappler
4837 }  // namespace tensorflow
4838