1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17
18 #include <complex>
19
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/cc/ops/array_ops.h"
24 #include "tensorflow/cc/ops/math_ops.h"
25 #include "tensorflow/cc/ops/nn_ops.h"
26 #include "tensorflow/cc/ops/resource_variable_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/tensor_testutil.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
32 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
33 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/test.h"
37
38 namespace tensorflow {
39 namespace grappler {
40
41 namespace {
42
// Name prefixes that specific ArithmeticOptimizer stages prepend to the nodes
// they create; the helper functions below use them to reconstruct the names
// of optimized nodes when asserting on the rewritten graph.
constexpr char kHoistFactorOptimizerDiv[] =
    "ArithmeticOptimizer/HoistCommonFactor_Div_";

constexpr char kHoistFactorOptimizerMul[] =
    "ArithmeticOptimizer/HoistCommonFactor_Mul_";

constexpr char kHoistFactorOptimizerAdd[] =
    "ArithmeticOptimizer/HoistCommonFactor_AddV2_";

constexpr char kSimplifyAggregationConst[] =
    "ArithmeticOptimizer/SimplifyAggregation_Const_";

constexpr char kSimplifyAggregationMul[] =
    "ArithmeticOptimizer/SimplifyAggregation_Mul_";
57
58 // Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
HoistMulName(const string & name)59 string HoistMulName(const string& name) {
60 return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
61 }
62
63 // Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
HoistDivName(const string & name)64 string HoistDivName(const string& name) {
65 return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
66 }
67
68 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
HoistAddName(const string & name)69 string HoistAddName(const string& name) {
70 return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
71 }
72
73 // Optimized name of Const node by SimplifyAggregation.
AggregationConstName(const string & name)74 string AggregationConstName(const string& name) {
75 return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
76 }
77
78 // Optimized name of Mul node by SimplifyAggregation.
AggregationMulName(const string & name)79 string AggregationMulName(const string& name) {
80 return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
81 }
82
VerifyGraphsMatch(const GraphDef & original_graph,const GraphDef & optimized_graph,int line)83 void VerifyGraphsMatch(const GraphDef& original_graph,
84 const GraphDef& optimized_graph, int line) {
85 EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
86 for (int i = 0; i < original_graph.node_size(); ++i) {
87 const NodeDef& original = original_graph.node(i);
88 const NodeDef& optimized = optimized_graph.node(i);
89 EXPECT_EQ(original.name(), optimized.name()) << line;
90 EXPECT_EQ(original.op(), optimized.op()) << line;
91 EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
92 for (int j = 0; j < original.input_size(); ++j) {
93 EXPECT_EQ(original.input(j), optimized.input(j)) << line;
94 }
95 }
96 }
97 } // namespace
98
TEST_F(ArithmeticOptimizerTest, NoOp) {
  // A trivial generated graph with nothing to optimize: the optimizer must
  // succeed and leave the graph completely unchanged.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  ArithmeticOptimizer optimizer;
  GraphDef output;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
111
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTile) {
  // Graph from b/176172427: multiplying by a constant of all ones whose
  // non-singleton dims broadcast against singleton dims of the input should
  // be rewritten into a single Tile of the input.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                       ops::Placeholder::Shape({1, 44, 1, 96, 1, 64}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 1, 2, 1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Evaluate the original graph to get reference output values.
  auto tensor =
      GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 44, 1, 96, 1, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The Mul must be gone, replaced by exactly one Tile.
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // Check the rewritten structure: Tile("input", new multiples Const).
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_mul"));
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(t, nullptr);
  ASSERT_NE(c, nullptr);
  EXPECT_EQ(t->op(), "Tile");
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), "input");
  EXPECT_EQ(t->input(1), c->name());
  EXPECT_EQ(t->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t->attr().at("Tmultiples").type(), c->attr().at("dtype").type());

  // The optimized graph must produce the same values as the original.
  auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
156
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTilePreserveControl) {
  // The ones-Const carries a control dependency on `input`; after the rewrite
  // that dependency must be preserved on the replacement Const node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones").WithControlDependencies(input),
                           1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("mul"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // The Mul is rewritten into a Tile, as in the test above.
  ASSERT_EQ(CountOpNodes(g, "Mul"), 0);
  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);

  // The new multiples Const must have inherited the "^input" control edge.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplaceMulWithBroadcastByTile";
  const NodeDef* c = node_map.GetNode(absl::StrCat(p, "_", "Const_mul"));
  ASSERT_NE(c, nullptr);
  ASSERT_EQ(c->input_size(), 1);
  EXPECT_TRUE(IsControlInput(c->input(0)));
  EXPECT_EQ(c->input(0), "^input");
}
190
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNoBroadcast) {
  // `ones` has exactly the same shape as the input, so the Mul involves no
  // broadcast and the rewrite must not fire.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 2, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Expect a no-op: the optimized graph matches the original exactly.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
219
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotConst) {
  // The rewrite requires one Mul operand to be a constant tensor of ones;
  // with two non-const placeholders the optimizer must leave the graph alone.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input1 = ops::Placeholder(s.WithOpName("input1"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 1, 1}));
  Output input2 = ops::Placeholder(s.WithOpName("input2"), DT_FLOAT,
                                   ops::Placeholder::Shape({1, 2, 1}));
  Output multiply = ops::Mul(s.WithOpName("multiply"), input1, input2);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor1 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto tensor2 = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 1}));
  auto expected = EvaluateNodes(item.graph, item.fetch,
                                {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Expect a no-op: the optimized graph matches the original exactly.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  // Evaluate the *optimized* graph. (The previous version re-evaluated
  // item.graph here, which made this numerical check vacuous.)
  auto result = EvaluateNodes(g, item.fetch,
                              {{"input1", tensor1}, {"input2", tensor2}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
251
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithBroadcastByTileNotOnes) {
  // The broadcast constant is 2.0, not 1.0, so replacing the Mul with a Tile
  // would change the result; the rewrite must not fire.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 1, 1}));
  Output ones = ops::Const(s.WithOpName("ones"), 2.0f, {1, 2, 1});
  Output multiply = ops::Mul(s.WithOpName("multiply"), input, ones);
  Output output = ops::Identity(s.WithOpName("output"), multiply);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 1, 1}));
  auto expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithBroadcastByTile(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 4);

  // Expect a no-op: the optimized graph matches the original exactly.
  VerifyGraphsMatch(item.graph, g, __LINE__);

  auto result = EvaluateNodes(g, item.fetch, {{"Placeholder", tensor}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
280
TEST_F(ArithmeticOptimizerTest, ReduceUpsamplingDims) {
  // Upsampling expressed as Reshape (rank 4 -> 6) -> Tile -> Reshape (rank
  // 6 -> 4) should be rewritten by ReduceUpsamplingDims to use a
  // lower-rank intermediate.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                  ops::Placeholder::Shape({1, 22, 48, 64}));
  Output reshape_a = ops::Reshape(
      s.WithOpName("reshape_a"), input,
      ops::Const(s.WithOpName("shape_a"), {1, 22, 1, 48, 1, 64}, {6}));
  Output tile =
      ops::Tile(s.WithOpName("tile"), reshape_a,
                ops::Const(s.WithOpName("multiples"), {1, 1, 2, 1, 2, 1}, {6}));
  Output reshape_b =
      ops::Reshape(s.WithOpName("reshape_b"), tile,
                   ops::Const(s.WithOpName("shape_b"), {1, 44, 96, 64}));
  Output output = ops::Identity(s.WithOpName("output"), reshape_b);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensor = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 22, 48, 64}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"input", tensor}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReduceUpsamplingDims(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 8);

  ASSERT_EQ(CountOpNodes(g, "Tile"), 1);
  ASSERT_EQ(CountOpNodes(g, "Reshape"), 2);
  ASSERT_EQ(CountOpNodes(g, "Const"), 3);

  // Expected chain: input -> new Reshape -> new Tile -> original reshape_b.
  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReduceUpsamplingDims";
  const NodeDef* ra =
      node_map.GetNode(absl::StrCat(p, "_", "Reshape_reshape_b"));
  const NodeDef* rb = node_map.GetNode("reshape_b");
  const NodeDef* t = node_map.GetNode(absl::StrCat(p, "_", "Tile_reshape_b"));
  ASSERT_NE(ra, nullptr);
  ASSERT_NE(rb, nullptr);
  ASSERT_NE(t, nullptr);

  ASSERT_EQ(rb->input_size(), 2);
  EXPECT_EQ(rb->input(0), t->name());
  ASSERT_EQ(t->input_size(), 2);
  EXPECT_EQ(t->input(0), ra->name());
  ASSERT_EQ(ra->input_size(), 2);
  EXPECT_EQ(ra->input(0), "input");

  // The rewritten graph must still compute the same values.
  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }

  // Check to make sure the first reshape is removed
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &g);
  EXPECT_EQ(g.node_size(), 6);

  {
    auto result = EvaluateNodes(g, item.fetch, {{"input", tensor}});
    ASSERT_EQ(result.size(), 1);
    test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
  }
}
347
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
  // Mul(x, x) and MulNoNan(x, x) should both be rewritten to Square(x),
  // keeping any control dependencies attached to the original node.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
  Output mul_no_nan = ops::MulNoNan(s.WithOpName("mul_no_nan"), d, d);
  Output id = ops::Identity(s.WithOpName("id"), mul);
  Output id2 = ops::Identity(s.WithOpName("id2"), mul_no_nan);

  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplaceMulWithSquare(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);
  const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
  const NodeDef* square_node = node_map.GetNode(absl::StrCat(p, "_", "mul"));

  // Mul(c, c) -> Square(c), with the "^d" control edge preserved.
  ASSERT_NE(square_node, nullptr);
  EXPECT_EQ(square_node->op(), "Square");
  ASSERT_EQ(square_node->input_size(), 2);
  EXPECT_EQ(square_node->input(0), "c");
  EXPECT_EQ(square_node->input(1), "^d");

  // MulNoNan(d, d) -> Square(d).
  const NodeDef* square_node2 =
      node_map.GetNode(absl::StrCat(p, "_", "mul_no_nan"));
  ASSERT_NE(square_node2, nullptr);
  EXPECT_EQ(square_node2->op(), "Square");
  ASSERT_EQ(square_node2->input_size(), 1);
  EXPECT_EQ(square_node2->input(0), "d");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // Also verify the MulNoNan -> Square rewrite numerically; the previous
  // version fetched two tensors but only compared the first one.
  test::ExpectTensorNear<float>(tensors[1], tensors_expected[1], 1e-6);
}
391
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshape) {
  // A chain of Pack nodes that duplicate the same tensor along new axes
  // should be rewritten into Tile + Reshape.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output o = ops::Identity(s.WithOpName("output"), c);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // Both Pack nodes collapse into one Tile followed by one Reshape, plus a
  // multiples Const and a shape Const.
  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
  const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
  const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
  const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
  const NodeDef* r_node = node_map.GetNode(absl::StrCat(p, "_", "Reshape_c"));
  const NodeDef* a_node = node_map.GetNode("a");
  ASSERT_NE(t_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(s_node, nullptr);
  ASSERT_NE(r_node, nullptr);
  ASSERT_NE(a_node, nullptr);

  EXPECT_EQ(c_node->op(), "Const");
  EXPECT_EQ(s_node->op(), "Const");

  // Check Reshape properties
  ASSERT_EQ(r_node->input_size(), 2);
  EXPECT_EQ(r_node->op(), "Reshape");
  EXPECT_EQ(r_node->input(0), t_node->name());
  EXPECT_EQ(r_node->input(1), s_node->name());

  // Check Tile properties
  ASSERT_EQ(t_node->input_size(), 2);
  EXPECT_EQ(t_node->op(), "Tile");
  EXPECT_EQ(t_node->input(0), a_node->name());
  EXPECT_EQ(t_node->input(1), c_node->name());
  EXPECT_EQ(t_node->attr().at("T").type(), DT_FLOAT);
  EXPECT_EQ(t_node->attr().at("Tmultiples").type(),
            c_node->attr().at("dtype").type());

  // The rewritten graph must produce the same values as the original.
  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
454
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeControlDeps) {
  // Same rewrite as ReplacePackWithTileReshape, but with control dependencies
  // attached to the Pack nodes; they must be carried over to the new Tile.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));

  Output x = ops::Identity(s.WithOpName("x"), a);
  Output y = ops::Identity(s.WithOpName("y"), a);

  // Stack creates Pack nodes.
  Output b = ops::Stack(s.WithOpName("b").WithControlDependencies(x), {a, a},
                        ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c").WithControlDependencies(y), {b, b},
                        ops::Stack::Axis(2));
  Output o = ops::Identity(s.WithOpName("output"), c);

  GrapplerItem item;
  item.fetch = {"output"};
  item.keep_ops = {"x", "y"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);
  EXPECT_EQ(CountOpNodes(g, "Identity"), 3);

  NodeMap node_map(&g);
  const string p = "ArithmeticOptimizer/ReplacePackWithTileReshape";
  const NodeDef* t_node = node_map.GetNode(absl::StrCat(p, "_", "Tile_c"));
  const NodeDef* c_node = node_map.GetNode(absl::StrCat(p, "_", "Multiples_c"));
  const NodeDef* s_node = node_map.GetNode(absl::StrCat(p, "_", "Shape_c"));
  const NodeDef* a_node = node_map.GetNode("a");
  ASSERT_NE(t_node, nullptr);
  ASSERT_NE(c_node, nullptr);
  ASSERT_NE(s_node, nullptr);
  ASSERT_NE(a_node, nullptr);

  // The control deps from both Pack nodes are hoisted onto the Tile.
  ASSERT_EQ(t_node->input_size(), 4);
  EXPECT_EQ(t_node->op(), "Tile");
  EXPECT_EQ(t_node->input(0), a_node->name());
  EXPECT_EQ(t_node->input(1), c_node->name());
  EXPECT_EQ(t_node->input(2), "^y");
  EXPECT_EQ(t_node->input(3), "^x");

  ASSERT_EQ(c_node->input_size(), 1);
  EXPECT_EQ(c_node->input(0), "^a");

  ASSERT_EQ(s_node->input_size(), 1);
  // EXPECT_EQ for the value check, matching the c_node check above. (The
  // previous version inconsistently used ASSERT_EQ here, which would abort
  // the test early on a value mismatch.)
  EXPECT_EQ(s_node->input(0), "^a");

  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
517
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileRemoveReshape) {
  // After ReplacePackWithTileReshape runs, the newly created Reshape feeds an
  // existing Reshape; RemoveRedundantReshape should then merge them.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates Pack nodes
  Output b = ops::Stack(s.WithOpName("b"), {a, a}, ops::Stack::Axis(3));
  Output c = ops::Stack(s.WithOpName("c"), {b, b}, ops::Stack::Axis(2));
  Output r =
      ops::Reshape(s.WithOpName("r"), c, ops::Const(s, {3, 10, 14, 11}, {4}));
  Output o = ops::Identity(s.WithOpName("output"), r);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7, 11}));
  auto expected = EvaluateNodes(item.graph, item.fetch, {{"a", a_t}});
  ASSERT_EQ(expected.size(), 1);

  GraphDef g;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  // First pass: Packs replaced, leaving two back-to-back Reshapes.
  EXPECT_EQ(g.node_size(), 8);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 3);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 2);

  // Second pass: the redundant intermediate Reshape is removed.
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &g);

  EXPECT_EQ(g.node_size(), 6);
  EXPECT_EQ(CountOpNodes(g, "Pack"), 0);
  EXPECT_EQ(CountOpNodes(g, "Tile"), 1);
  EXPECT_EQ(CountOpNodes(g, "Const"), 2);
  EXPECT_EQ(CountOpNodes(g, "Reshape"), 1);

  // The final graph must still compute the same values.
  auto result = EvaluateNodes(g, item.fetch, {{"a", a_t}});
  ASSERT_EQ(result.size(), 1);
  test::ExpectTensorNear<float>(result[0], expected[0], 1e-6);
}
560
TEST_F(ArithmeticOptimizerTest, ReplacePackWithTileReshapeOutOfRange) {
  // Stacking along axis 4 of a rank-4 tensor is outside the range this
  // rewrite handles, so the optimizer must leave the graph untouched.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Shape({3, 5, 7, 11}));
  // Stack creates a Pack node.
  Output stacked =
      ops::Stack(scope.WithOpName("b"), {a, a}, ops::Stack::Axis(4));
  Output out = ops::Identity(scope.WithOpName("output"), stacked);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyReplacePackWithTileReshape(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Expect a no-op.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);
}
580
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAdjacentNodes) {
  // Neg(Neg(x)) and Reciprocal(Reciprocal(x)) are involutions applied twice;
  // all four nodes should be removed, leaving id -> c directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto neg1 = ops::Neg(s.WithOpName("neg1"), c);
  auto neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
  auto id = ops::Identity(s.WithOpName("id"), recip2);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Negation and Reciprocal nodes cancelled each other.
  // Only "c" and "id" remain; the test relies on "id" being node index 1.
  ASSERT_EQ(output.node_size(), 2);
  EXPECT_EQ(output.node(1).name(), "id");
  ASSERT_EQ(output.node(1).input_size(), 1);
  EXPECT_EQ(output.node(1).input(0), "c");

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
612
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionAroundValuePreservingChain) {
  // The two Reciprocal nodes are separated by value-preserving ops (Identity,
  // Squeeze); the optimizer should still cancel them across the chain.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Check that Reciprocal nodes were removed from the graph.
  EXPECT_EQ(output.node_size(), 3);

  // And const directly flows into squeeze.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "squeeze") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      found++;
    } else if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "squeeze");
      found++;
    }
  }
  // Both rewired nodes must have been seen.
  EXPECT_EQ(found, 2);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
658
TEST_F(ArithmeticOptimizerTest, RemoveInvolutionSkipControlDependencies) {
  // recip2 is only connected to the first Reciprocal chain via a control
  // dependency (its data input is `c`), so the involution rewrite must not
  // fire across the control edge.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
  auto recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
  auto id1 = ops::Identity(s.WithOpName("id1"), recip1);
  auto squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
  auto recip2 = ops::Reciprocal(
      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
  auto id2 = ops::Identity(s.WithOpName("id2"), recip2);

  std::vector<string> fetch = {"id2"};

  GrapplerItem item;
  item.fetch = fetch;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveInvolution(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);  // do not prune in this test

  // The optimizer should be a noop.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
691
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
  // Add(x, x) should be rewritten by SimplifyAggregation into
  // Mul(Const(2), x).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add"), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  // The new Const(2) is anchored to x via a control edge so it is placed
  // consistently with its consumer.
  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  // "\0\0\0@" is the 4-byte little-endian IEEE 754 encoding of 2.0f.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // The Add is replaced by Mul(Const(2), x).
  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 2);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");

  // The fetch node is rewired onto the new Mul.
  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
737
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
  // Same rewrite as TrivialSumsSimple, but the Add carries a control
  // dependency on y that must survive on the replacement Mul.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
  Output id = ops::Identity(s.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  std::vector<string> fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 6);

  const string optimized_const_name = AggregationConstName("add");
  const string optimized_mul_name = AggregationMulName("add");

  const NodeDef* new_const = node_map.GetNode(optimized_const_name);
  ASSERT_NE(new_const, nullptr);
  ASSERT_EQ(new_const->input_size(), 1);
  EXPECT_EQ(new_const->input(0), "^x");
  // "\0\0\0@" is the 4-byte little-endian IEEE 754 encoding of 2.0f.
  EXPECT_EQ(new_const->attr().at("value").tensor().tensor_content(),
            string("\0\0\0@", 4));

  // Mul(Const(2), x) with the "^y" control dependency carried over.
  const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
  ASSERT_NE(new_mul, nullptr);
  ASSERT_EQ(new_mul->input_size(), 3);
  EXPECT_EQ(new_mul->input(0), optimized_const_name);
  EXPECT_EQ(new_mul->input(1), "x");
  EXPECT_EQ(new_mul->input(2), "^y");

  const NodeDef* new_id = node_map.GetNode("id");
  ASSERT_NE(new_id, nullptr);
  ASSERT_EQ(new_id->input_size(), 1);
  EXPECT_EQ(new_id->input(0), optimized_mul_name);

  auto tensors = EvaluateNodes(output, fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
785
// A tree of trivial sums (p+p) whose nodes are spread across several CPU and
// GPU devices.  With Add->AddN combining disabled, each x+x is rewritten to
// 2*x and the common factor p is hoisted, leaving a single Mul of the
// placeholder by a tree of constant additions.
TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
  // Test case from b/69059093.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
  Output add = ops::Add(s.WithOpName("Add"), p, p);
  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
  Output id = ops::Identity(s.WithOpName("id"), add6);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Place the 7 graph nodes on a mix of devices (indexed by their position in
  // the GraphDef) to check the rewrite works across device boundaries.
  const std::vector<string> devices{
      "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
      "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
  };
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device(devices[i]);
  }

  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  // Disable AddN combining so the x+x -> 2*x rewrite is what gets exercised.
  DisableAddToAddNCombining(&optimizer);

  GraphDef output;
  DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  // Mul(p,
  //     Add_6(Add_4(Const(2), Const(2)),
  //           Add_5(Const(2), Const(2)))
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 8);

  const NodeDef* id_node = node_map.GetNode("id");
  ASSERT_NE(id_node, nullptr);
  ASSERT_EQ(id_node->input_size(), 1);
  EXPECT_EQ(id_node->input(0), HoistMulName("Add_6"));

  // The hoisted Mul multiplies the placeholder by the constant sum tree.
  const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
  ASSERT_NE(mul_node, nullptr);
  ASSERT_EQ(mul_node->input_size(), 2);
  EXPECT_EQ(mul_node->input(0), "Placeholder");
  EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));

  const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
  ASSERT_NE(add_6_node, nullptr);
  ASSERT_EQ(add_6_node->input_size(), 2);
  EXPECT_EQ(add_6_node->input(0), HoistAddName("Add_4"));
  EXPECT_EQ(add_6_node->input(1), HoistAddName("Add_5"));

  const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
  ASSERT_NE(add_4_node, nullptr);
  EXPECT_EQ(add_4_node->op(), "Add");
  ASSERT_EQ(2, add_4_node->input_size());
  EXPECT_EQ(add_4_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_4_node->input(1), AggregationConstName("Add_1"));

  const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
  ASSERT_NE(add_5_node, nullptr);
  EXPECT_EQ(add_5_node->op(), "Add");
  ASSERT_EQ(add_5_node->input_size(), 2);
  EXPECT_EQ(add_5_node->input(0), AggregationConstName("Add"));
  EXPECT_EQ(add_5_node->input(1), AggregationConstName("Add_1"));

  // The generated constants are anchored to the placeholder via control
  // edges.
  const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
  ASSERT_NE(add_const_node, nullptr);
  EXPECT_EQ(add_const_node->op(), "Const");
  ASSERT_EQ(add_const_node->input_size(), 1);
  EXPECT_EQ(add_const_node->input(0), "^Placeholder");

  const NodeDef* add_1_const_node =
      node_map.GetNode(AggregationConstName("Add_1"));
  ASSERT_NE(add_1_const_node, nullptr);
  EXPECT_EQ(add_1_const_node->op(), "Const");
  ASSERT_EQ(add_1_const_node->input_size(), 1);
  EXPECT_EQ(add_1_const_node->input(0), "^Placeholder");
}
868
// HoistCommonFactor on a sum of products with a common factor:
// x*y1 + y2*x is rewritten as x * (y1 + y2).  For an AddN root with
// mismatched operand shapes the rewrite must not fire.
TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
      Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
      Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
      // y2 either matches y1's shape or is a broadcastable 1x1 tensor.
      Output y2 = matching_shapes
                      ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
                      : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
      // The common factor x appears on different sides of the two Muls.
      Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
      Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
      Output id =
          use_addn ? ops::Identity(s.WithOpName("id"),
                                   ops::AddN(s.WithOpName("add"), {mul1, mul2}))
                   : ops::Identity(s.WithOpName("id"),
                                   ops::Add(s.WithOpName("add"), mul1, mul2));

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(s.ToGraphDef(&item.graph));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(tensors_expected.size(), 1);
      ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
      EnableOnlyHoistCommonFactor(&optimizer);

      GraphDef output;
      OptimizeTwice(&optimizer, &item, &output);

      // We expect the following rewrite(s) to occur:
      //
      //        Add                 Mul
      //      /    \               /   \
      //    Mul    Mul     ->     x    Add
      //    / \    / \                 / \
      //   x  y1  y2  x              y1   y2
      //
      // If "root" op is AddN and shapes does not match, this rewrite is not
      // possible and graph should stay intact.
      NodeMap node_map(&output);

      if (use_addn && !matching_shapes) {
        VerifyGraphsMatch(item.graph, output, __LINE__);
      } else {
        EXPECT_EQ(output.node_size(), 9);

        const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
        ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
        ASSERT_EQ(new_add_node->input_size(), 2);
        EXPECT_EQ(new_add_node->input(0), "y1");
        EXPECT_EQ(new_add_node->input(1), "y2");

        const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
        ASSERT_NE(new_mul_node, nullptr) << "Hoisted Mul node not found";
        ASSERT_EQ(new_mul_node->input_size(), 2);
        EXPECT_EQ(new_mul_node->input(0), "x");
        EXPECT_EQ(new_mul_node->input(1), new_add_node->name());

        const NodeDef* id_node = node_map.GetNode("id");
        ASSERT_NE(id_node, nullptr) << "Id node not found";
        EXPECT_EQ(id_node->name(), "id");
        ASSERT_EQ(id_node->input_size(), 1);
        EXPECT_EQ(id_node->input(0), HoistMulName("add"));
      }
      // In both cases the graph must still compute the original value.
      auto tensors = EvaluateNodes(output, item.fetch);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
    }
  }
}
938
// HoistCommonFactor on a sum of divisions with a common divisor:
// y1/x + y2/x is rewritten as (y1 + y2) / x.  The rewrite must be skipped
// for integer types and for an AddN root with mismatched operand shapes.
//
// Consistency fix: use ASSERT_NE(ptr, nullptr) and actual-first EXPECT_EQ
// argument order, matching the sibling HoistFactorMul test and the rest of
// this file (ASSERT_NE also prints both operands on failure, unlike
// ASSERT_TRUE(ptr != nullptr)).
TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
  for (bool matching_shapes : {true, false}) {
    for (bool use_addn : {true, false}) {
      for (bool use_ints : {true, false}) {
        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
        Output x = use_ints
                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
        Output y1 = use_ints
                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
        // y2 either matches y1's shape or is a broadcastable 1x1 tensor.
        Output y2;
        if (matching_shapes) {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
        } else {
          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
        }
        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
        Output id =
            use_addn
                ? ops::Identity(s.WithOpName("id"),
                                ops::AddN(s.WithOpName("add"), {div1, div2}))
                : ops::Identity(s.WithOpName("id"),
                                ops::Add(s.WithOpName("add"), div1, div2));

        GrapplerItem item;
        item.fetch = {"id"};
        TF_CHECK_OK(s.ToGraphDef(&item.graph));

        auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
        ASSERT_EQ(tensors_expected.size(), 1);

        ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
        EnableOnlyHoistCommonFactor(&optimizer);

        GraphDef output;
        OptimizeTwice(&optimizer, &item, &output);

        // We expect the following rewrite(s) to occur:
        //
        //        Add                Div
        //      /    \              /   \
        //    Div    Div    ->    Add    x
        //    / \    / \          / \
        //   y1  x  y2  x       y1   y2
        //
        // If "root" op is AddN and shapes does not match, or the type is
        // integer, this rewrite is not possible and graph should stay intact.
        NodeMap node_map(&output);

        if ((use_addn && !matching_shapes) || use_ints) {
          VerifyGraphsMatch(item.graph, output, __LINE__);
        } else {
          EXPECT_EQ(output.node_size(), 9);

          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
          ASSERT_NE(new_add_node, nullptr) << "Hoisted Add node not found";
          ASSERT_EQ(new_add_node->input_size(), 2);
          EXPECT_EQ(new_add_node->input(0), "y1");
          EXPECT_EQ(new_add_node->input(1), "y2");

          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
          ASSERT_NE(new_div_node, nullptr) << "Hoisted Div node not found";
          ASSERT_EQ(new_div_node->input_size(), 2);
          EXPECT_EQ(new_div_node->input(0), new_add_node->name());
          EXPECT_EQ(new_div_node->input(1), "x");

          const NodeDef* id_node = node_map.GetNode("id");
          ASSERT_NE(id_node, nullptr) << "Id node not found";
          EXPECT_EQ(id_node->name(), "id");
          ASSERT_EQ(id_node->input_size(), 1);
          EXPECT_EQ(id_node->input(0), HoistDivName("add"));
        }
        // In all cases the graph must still compute the original value.
        auto tensors = EvaluateNodes(output, item.fetch);
        ASSERT_EQ(tensors.size(), 1);
        if (use_ints) {
          test::ExpectTensorEqual<int32>(tensors[0], tensors_expected[0]);
        } else {
          test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
        }
      }
    }
  }
}
1026
// Conj followed by Transpose collapses into a single ConjugateTranspose.
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output complex_in = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conjugated = ops::Conj(scope.WithOpName("conj"), complex_in);
  Output transposed =
      ops::Transpose(scope.WithOpName("trans"), conjugated, perm);

  GrapplerItem item;
  item.fetch = {"trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "trans");

  // The fused node reads the complex tensor directly, bypassing "conj".
  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "ConjugateTranspose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1064
// Conj followed by ConjugateTranspose: the two conjugations cancel, leaving
// a plain Transpose.
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output complex_in = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output conjugated = ops::Conj(scope.WithOpName("conj"), complex_in);
  Output transposed = ops::ConjugateTranspose(
      scope.WithOpName("conjugate_trans"), conjugated, perm);

  GrapplerItem item;
  item.fetch = {"conjugate_trans"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name =
      absl::StrCat("ArithmeticOptimizer/FoldConjugateIntoTranspose", "_",
                   "conjugate_trans");

  // Double conjugation cancels: the fused op is a plain Transpose fed by the
  // complex tensor itself.
  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "Transpose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1104
// Transpose followed by Conj (the reverse order of FuseConjAndTranspose)
// also fuses into a single ConjugateTranspose.
TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output complex_in = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 0}, {2});
  Output transposed =
      ops::Transpose(scope.WithOpName("trans"), complex_in, perm);
  Output conjugated = ops::Conj(scope.WithOpName("conj"), transposed);

  GrapplerItem item;
  item.fetch = {"conj"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 7);

  const string fused_name = absl::StrCat(
      "ArithmeticOptimizer/FoldConjugateIntoTranspose", "_", "conj");

  // The fused node replaces the trans->conj chain and reads "z" directly.
  const NodeDef* fused = node_map.GetNode(fused_name);
  ASSERT_NE(fused, nullptr);
  EXPECT_EQ(fused->op(), "ConjugateTranspose");
  ASSERT_EQ(fused->input_size(), 2);
  EXPECT_EQ(fused->input(0), "z");
  EXPECT_EQ(fused->input(1), "perm");

  auto actual = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorEqual<complex64>(actual[0], expected[0]);
}
1142
// Transpose(a) and Transpose(b) feeding a MatMul-family op fold into the
// matmul's transpose_a/transpose_b (or adj_x/adj_y for batch matmuls)
// attributes, eliminating the explicit Transpose nodes.
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
  // absl::string_view avoids materializing a std::string from each literal on
  // every iteration (clang-tidy: performance-for-range-copy).
  for (const absl::string_view matmul_type :
       {"MatMul", "SparseMatMul", "BatchMatMul", "BatchMatMulV2"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
    Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
    Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
    Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);

    Output matmul;
    auto matmul_op = s.WithOpName("matmul");
    if (matmul_type == "MatMul") {
      matmul = ops::MatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "SparseMatMul") {
      matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMul") {
      matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
    } else if (matmul_type == "BatchMatMulV2") {
      matmul = ops::BatchMatMulV2(matmul_op, trans_a, trans_b);
    }

    auto identity = ops::Identity(s.WithOpName("identity"), matmul);

    GrapplerItem item;
    item.fetch = {"identity"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlyFoldTransposeIntoMatMul(&optimizer);
    GraphDef output;
    OptimizeTwice(&optimizer, &item, &output);
    NodeMap node_map(&output);

    EXPECT_EQ(output.node_size(), 8);

    const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
    const string optimized_name = absl::StrCat(p, "_", "matmul");

    // The fused matmul consumes a and b directly; the transposition is now
    // expressed through attributes.
    const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
    ASSERT_NE(matmul_fused_node, nullptr);
    ASSERT_EQ(matmul_fused_node->input_size(), 2);
    EXPECT_EQ(matmul_fused_node->input(0), "a");
    EXPECT_EQ(matmul_fused_node->input(1), "b");

    // Batch matmuls use adjoint flags; plain/sparse matmuls use transpose
    // flags.
    if (matmul_type == "BatchMatMul" || matmul_type == "BatchMatMulV2") {
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
    } else {
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
      EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
    }

    const NodeDef* identity_node = node_map.GetNode("identity");
    ASSERT_NE(identity_node, nullptr);
    ASSERT_EQ(identity_node->input_size(), 1);
    EXPECT_EQ(identity_node->input(0), optimized_name);

    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
1210
// ConjugateTranspose feeding a BatchMatMul folds into the matmul's
// adj_x/adj_y attributes (adjoint == conjugate transpose).
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output re_a =
      ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output im_a =
      ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  Output re_b =
      ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output im_b =
      ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
  Output a = ops::Complex(s.WithOpName("a"), re_a, im_a);
  Output b = ops::Complex(s.WithOpName("b"), re_b, im_b);
  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
  Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
  Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
  Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
  Output identity = ops::Identity(s.WithOpName("identity"), matmul);

  GrapplerItem item;
  item.fetch = {"identity"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);

  NodeMap node_map(&output);
  EXPECT_EQ(output.node_size(), 12);

  const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
  const string optimized_name = absl::StrCat(p, "_", "matmul");

  // The optimized matmul consumes a and b directly with both adjoint flags
  // set, replacing the explicit ConjugateTranspose nodes.
  const NodeDef* optimized_matmul = node_map.GetNode(optimized_name);
  ASSERT_NE(optimized_matmul, nullptr);
  ASSERT_EQ(optimized_matmul->input_size(), 2);
  EXPECT_EQ(optimized_matmul->input(0), "a");
  EXPECT_EQ(optimized_matmul->input(1), "b");
  EXPECT_TRUE(optimized_matmul->attr().at("adj_x").b());
  EXPECT_TRUE(optimized_matmul->attr().at("adj_y").b());

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
}
1259
// A Reshape (or BroadcastTo) whose target shape is provably identical to the
// input's symbolic shape ([-1,3,28,28] -> [batch_size,3,28,28]) is redundant
// and gets removed.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
  for (bool is_broadcastto : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
    Output inputs_shape = ops::Shape(s, inputs);
    // The target shape of the reshape is the concatenation of `batch_size` and
    // [3,28,28].
    Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                   ops::Const(s, {1}, {1}));
    Output target_shape = ops::Concat(
        s.WithOpName("target_shape"),
        {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
    // The Output locals are intentionally unused; constructing the ops adds
    // the "outputs" node to the scope's graph.
    if (is_broadcastto) {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::BroadcastTo(s, inputs, target_shape));
    } else {
      Output outputs = ops::Identity(s.WithOpName("outputs"),
                                     ops::Reshape(s, inputs, target_shape));
    }

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
    auto tensors_expected =
        EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyRemoveRedundantReshape(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    // The identity Reshape/BroadcastTo must be gone and the result unchanged.
    EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
    EXPECT_EQ(CountOpNodes(output, "BroadcastTo"), 0);
    auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  }
}
1301
// Identity reshape where both the input shape and the target shape are built
// from symbolic dimensions; only removable in aggressive mode, which assumes
// feeds have shapes consistent with the placeholders.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeIdentityReshapeBetweenSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
  Output inputs_shape = ops::Shape(s, inputs);
  // The target shape of the reshape is the concatenation of `batch_size`, 3,
  // `height`, and `width`.
  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
                                 ops::Const(s, {1}, {1}));
  Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
                             ops::Const(s, {1}, {1}));
  Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
                            ops::Const(s, {1}, {1}));
  Output target_shape =
      ops::Concat(s.WithOpName("target_shape"),
                  {batch_size, ops::Const(s, {3}, {1}), height, width},
                  ops::Const(s, {0}, {}));
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  // Assume valid feed shape in aggressive mode.
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The symbolic identity reshape must be eliminated.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1343
// In default (non-aggressive) mode the placeholder's declared shape is not
// trusted, so even a reshape to the placeholder's exact static shape must be
// preserved.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotAssumeValidFeeds) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
  Output target_shape = ops::Const(s, {4, 3, 28, 28}, {4});
  Output reshape = ops::Reshape(s, inputs, target_shape);
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", x_t}};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The reshape is preserved because the shape of the placeholder can be
  // different from the shape of the actual feed.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1374
// In aggressive mode the optimizer trusts that feeds match the placeholder's
// declared shape, so a Reshape to that exact static shape is removed.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeAssumeValidFeedsInAggressiveMode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input = ops::Placeholder(scope, DT_FLOAT,
                                  ops::Placeholder::Shape({4, 3, 28, 28}));
  Output shape_const = ops::Const(scope, {4, 3, 28, 28}, {4});
  Output reshaped = ops::Reshape(scope, input, shape_const);
  Output fetched = ops::Identity(scope.WithOpName("outputs"), reshaped);

  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 3, 28, 28}));
  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"Placeholder", feed_t}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The no-op Reshape has been eliminated and the result is unchanged.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
  auto actual = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(actual.size(), 1);
  test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}
1403
// A reshape that can change the split between dimensions is not an identity
// and must be kept.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeNotIdentityReshape) {
  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
  // be from [4,3,28,28] to [8,6,28,28].
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28}));
  item.feed = {{"Placeholder", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The non-identity reshape must survive optimization.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1431
// A reshape target with more than one unknown (-1) dimension cannot be proven
// to be an identity, so it must be kept.
TEST_F(ArithmeticOptimizerTest,
       RemoveRedundantReshapeNotIdentityReshapeTooManyUnknownDimSizes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output input =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
  Output reshaped = ops::Reshape(scope, input, ops::Const(scope, {-1, -1}, {2}));
  Output fetched = ops::Identity(scope.WithOpName("outputs"), reshaped);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantReshape(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // The ambiguous reshape must survive optimization.
  EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
}
1451
// Two back-to-back reshapes (optionally separated by a shape-preserving unary
// op) collapse into one — unless the intermediate reshape's output is fed
// directly or has additional consumers.
TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeCombineReshapes) {
  for (bool include_unary_chain : {false, true}) {
    // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The
    // two reshapes should be combined.
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output nchw_vect_c =
        ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_FLOAT,
                         ops::Placeholder::Shape({8, 3, 28, 28, 4}));
    Output transpose =
        ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
                       ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
    Output nhwc = ops::Reshape(
        s.WithOpName("nhwc"), transpose,
        ops::Const(
            s.WithControlDependencies(nchw_vect_c).WithOpName("nhwc_shape"),
            {8, 28, 28, 12}, {4}));
    // Optionally interpose a shape-preserving unary op (Cos) between the two
    // reshapes.
    Output flatten = ops::Reshape(
        s.WithOpName("flatten"),
        (include_unary_chain ? ops::Cos(s.WithOpName("Cos"), nhwc) : nhwc),
        ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
    Output output0 = ops::Identity(s.WithOpName("output0"), flatten);
    Output output1 = ops::Identity(s.WithOpName("output1"), flatten);

    GraphDef graph;
    TF_CHECK_OK(s.ToGraphDef(&graph));
    auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 3, 28, 28, 4}));
    // Evaluate both the final output and the intermediate "nhwc" tensor; the
    // latter is reused as a feed in the second sub-case below.
    auto eval =
        EvaluateNodes(graph, {"output0", "nhwc"}, {{"nchw_vect_c", x_t}});

    ASSERT_EQ(eval.size(), 2);
    auto expected_output_t = eval[0];
    auto nhwc_t = eval[1];

    // Base case: the two reshapes are combined into a single one.
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is the feed tensor.
    // (Expected no reshape removal to happen.)
    {
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1"};
      item.feed = {{"nhwc", nhwc_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 2);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
    }

    // Test when the first reshape node output is consumed by multiple nodes
    // (Expected no reshape removal to happen.)
    {
      // Note: this extends the same Scope `s`, so "output2" exists only in
      // the graph re-exported below, not in the outer `graph`.
      Output output2 = ops::Identity(s.WithOpName("output2"), nhwc);
      GraphDef graph;
      TF_CHECK_OK(s.ToGraphDef(&graph));
      GrapplerItem item;
      item.graph = graph;
      item.fetch = {"output0", "output1", "output2"};
      item.feed = {{"nchw_vect_c", x_t}};

      GraphDef output;
      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveRedundantReshape(&optimizer);
      OptimizeTwiceAndPrune(&optimizer, &item, &output);

      EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 3);
      test::ExpectTensorEqual<float>(tensors[0], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[1], expected_output_t);
      test::ExpectTensorEqual<float>(tensors[2], nhwc_t);
    }
  }
}
1548
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsCast) {
  // Exercises the reorder-cast-and-value-preserving rewrite:
  //   Transpose(Cast(uint8 -> float)) => Cast(Transpose(uint8))
  // so the data-movement op (Transpose) runs on the smaller uint8 tensor.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one Transpose must remain, and it must now operate on the uint8
  // input (i.e. it was moved before the Cast).
  const NodeDef* transpose_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    }
  }
  ASSERT_NE(transpose_node, nullptr);

  // The Cast must now consume the Transpose output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(transpose_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): OptimizeAndPrune appears to leave the optimized graph in
  // item.graph, so this re-evaluates the rewritten graph -- confirm against
  // arithmetic_optimizer_test_utils.h.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1593
TEST_F(ArithmeticOptimizerTest, ReorderS2DCastProducerIsCast) {
  // Exercises the same cast-reordering rewrite with SpaceToDepth as the
  // value-preserving op: S2D(Cast(uint8 -> float)) => Cast(S2D(uint8)).
  // TODO(jingyue): Evaluate S2D+Cast on GPU as well. We can't simply put nodes
  // under a /GPU:0 scope, because this test would fail if the testing machine
  // doesn't have a GPU. Maybe EvaluateNodes should allow soft placement?
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output outputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  outputs = ops::Cast(s, outputs, DT_FLOAT);
  outputs = ops::SpaceToDepth(s, outputs, 2);
  outputs = ops::Identity(s.WithOpName("outputs"), outputs);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Exactly one SpaceToDepth must remain, now operating on uint8 data.
  const NodeDef* s2d_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "SpaceToDepth") {
      EXPECT_EQ(s2d_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      s2d_node = &node;
    }
  }
  ASSERT_NE(s2d_node, nullptr);

  // The Cast must now consume the SpaceToDepth output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(s2d_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): OptimizeAndPrune appears to leave the optimized graph in
  // item.graph, so this re-evaluates the rewritten graph -- confirm against
  // arithmetic_optimizer_test_utils.h.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1640
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastProducerIsTranspose) {
  // Exercises the reordering in the other direction: a narrowing Cast
  // (float -> uint8) that follows a Transpose is moved before it,
  //   Cast(Transpose(x)) => Transpose(Cast(x)),
  // so the Transpose moves the smaller uint8 data.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_fp32 =
      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nchw_fp32 =
      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
  Output nchw_uint8 = ops::Cast(s, nchw_fp32, DT_UINT8);
  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // A constant, integral-valued input -- presumably so the float -> uint8
  // conversion is exact regardless of evaluation order; confirm.
  auto input_t =
      GenerateConstantTensor<DT_FLOAT>(TensorShape({8, 28, 28, 3}), 42.0f);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // The single remaining Cast must now read directly from the placeholder.
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Cast") {
      EXPECT_EQ(cast_node, nullptr);
      cast_node = &node;
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(NodeName(node.input(0)), "Placeholder");
    }
  }
  ASSERT_NE(cast_node, nullptr);

  // The Transpose must now run on uint8 and consume the Cast output.
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(cast_node->name(), NodeName(node.input(0)));
    }
  }

  // NOTE(review): OptimizeAndPrune appears to leave the optimized graph in
  // item.graph, so this re-evaluates the rewritten graph -- confirm against
  // arithmetic_optimizer_test_utils.h.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<uint8>(tensors[0], tensors_expected[0]);
}
1688
TEST_F(ArithmeticOptimizerTest, ReorderTransposeReverseCast) {
  // Two value-preserving ops (Reverse then Transpose) follow a widening Cast;
  // the optimizer should push the Cast past both of them so they operate on
  // uint8: Transpose(Reverse(Cast(x))) => Cast(Transpose(Reverse(x))).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output nhwc_uint8 =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
  Output nhwc_fp32_reversed =
      ops::Reverse(s, nhwc_fp32, ops::Const(s, {0}, {1}));
  Output nchw_fp32_reversed =
      ops::Transpose(s, nhwc_fp32_reversed, ops::Const(s, {0, 3, 1, 2}, {4}));

  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32_reversed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto input_t = GenerateRandomTensor<DT_UINT8>(TensorShape({8, 28, 28, 3}));
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &output);

  // Collect the single Transpose / ReverseV2 / Cast. The two data-movement
  // ops must now carry T = DT_UINT8.
  const NodeDef* reverse_node = nullptr;
  const NodeDef* transpose_node = nullptr;
  const NodeDef* cast_node = nullptr;
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Transpose") {
      EXPECT_EQ(transpose_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      transpose_node = &node;
    } else if (node.op() == "ReverseV2") {
      EXPECT_EQ(reverse_node, nullptr);
      EXPECT_EQ(node.attr().at("T").type(), DT_UINT8);
      reverse_node = &node;
    } else if (node.op() == "Cast") {
      cast_node = &node;
    }
  }
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(reverse_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  // Expected chain: Placeholder -> ReverseV2 -> Transpose -> Cast.
  ASSERT_EQ(reverse_node->input_size(), 2);
  EXPECT_EQ(NodeName(reverse_node->input(0)), "Placeholder");
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(NodeName(transpose_node->input(0)), reverse_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(NodeName(cast_node->input(0)), transpose_node->name());

  // NOTE(review): OptimizeAndPrune appears to leave the optimized graph in
  // item.graph, so this re-evaluates the rewritten graph -- confirm against
  // arithmetic_optimizer_test_utils.h.
  auto tensors =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", input_t}});
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
1745
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCastCheckNumericsToIdentity) {
  // A CheckNumerics consumer must block the cast-reordering rewrite: the
  // optimizer is expected to leave the Cast/CheckNumerics chain untouched.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_uint8 = ops::Placeholder(scope, DT_UINT8,
                                      ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_float = ops::Cast(scope, src_uint8, DT_FLOAT);
  Output checked = ops::CheckNumerics(scope, as_float, "foo");
  Output outputs = ops::Identity(scope.WithOpName("outputs"), checked);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  // No rewrite expected: optimized graph must equal the input graph.
  CompareGraphs(item.graph, optimized);
}
1762
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsCast) {
  // Here the Cast already narrows the data (float -> uint8) before the
  // Transpose, so reordering would be a pessimization; the optimizer must
  // leave the graph alone.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_fp32 = ops::Placeholder(scope, DT_FLOAT,
                                     ops::Placeholder::Shape({8, 28, 28, 3}));
  Output as_uint8 = ops::Cast(scope, src_fp32, DT_UINT8);
  Output transposed =
      ops::Transpose(scope, as_uint8, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), transposed);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1780
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCastProducerIsTranspose) {
  // The Transpose already runs on the small uint8 tensor and the Cast widens
  // afterwards; swapping them would move more data, so the optimizer must
  // leave the graph unchanged.
  tensorflow::Scope scope =
      tensorflow::Scope::NewRootScope().WithDevice("/CPU:0");
  Output src_uint8 = ops::Placeholder(scope, DT_UINT8,
                                      ops::Placeholder::Shape({8, 28, 28, 3}));
  Output transposed =
      ops::Transpose(scope, src_uint8, ops::Const(scope, {0, 3, 1, 2}, {4}));
  Output as_float = ops::Cast(scope, transposed, DT_FLOAT);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), as_float);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef optimized;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &optimized));
  CompareGraphs(item.graph, optimized);
}
1798
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
  // transpose1+transpose2 compose to the identity permutation and transpose3
  // uses the identity permutation directly; all three must be removed.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output perm3 = ops::Const(scope.WithOpName("perm3"), {0, 1, 2, 3}, {4});
  Output transpose1 =
      ops::Transpose(scope.WithOpName("transpose1"), inputs, perm1);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), transpose1, perm2);
  Output transpose3 =
      ops::Transpose(scope.WithOpName("transpose3"), inputs, perm3);
  Output id1 = ops::Identity(scope.WithOpName("id1"), transpose2);
  Output id2 = ops::Identity(scope.WithOpName("id2"), transpose3);

  GrapplerItem item;
  item.fetch = {"id1", "id2"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Only the fetch nodes and the original inputs should survive pruning.
  const std::set<string> expected = {"id1", "id2", "inputs_shape", "inputs"};
  std::set<string> remaining;
  for (const NodeDef& node : output.node()) {
    remaining.insert(node.name());
  }
  EXPECT_EQ(remaining, expected);
}
1831
TEST_F(ArithmeticOptimizerTest, RemoveIdentityConjugateTransposes) {
  // A ConjugateTranspose with the identity permutation still conjugates its
  // input, so it must be rewritten to a plain Conj rather than removed.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output real_part =
      ops::Const(scope.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  Output imag_part =
      ops::Const(scope.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  Output z = ops::Complex(scope.WithOpName("z"), real_part, imag_part);
  Output perm = ops::Const(scope.WithOpName("perm"), {0, 1}, {2});
  Output transpose =
      ops::ConjugateTranspose(scope.WithOpName("trans"), z, perm);
  Output id = ops::Identity(scope.WithOpName("id"), transpose);

  GrapplerItem item;
  item.fetch = {"id"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  EXPECT_EQ(output.node_size(), 5);

  // The stage rewrites "trans" into a Conj node with a derived name.
  const string optimized_name = absl::StrCat(
      "ArithmeticOptimizer/RemoveIdentityTranspose", "_", "trans");
  const NodeDef* conj = node_map.GetNode(optimized_name);
  ASSERT_NE(conj, nullptr);
  EXPECT_EQ(conj->op(), "Conj");
  ASSERT_EQ(conj->input_size(), 1);
  EXPECT_EQ(conj->input(0), "z");
}
1862
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesMultipleOutputs) {
  // A Transpose/Transpose pair that composes to the identity must be removed
  // even when its input is one output of a multi-output Split; the Concat is
  // then expected to read the Split outputs directly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  // NOTE(review): `inputs_shape` is never used below (the Placeholder carries
  // its own shape, and its value {8, 9, 28, 28} does not even match it);
  // looks like a leftover -- consider removing.
  Output inputs_shape =
      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 12, 28, 28}));
  // Split along axis 1 into three 4-channel slices.
  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
  Output branch0 = split[0];
  // perm1 followed by perm2 is the identity permutation.
  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
  Output branch2 = split[2];
  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({8, 12, 28, 28}));
  item.feed = {{"inputs", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // With the transposes gone, Concat consumes the Split outputs directly
  // ("Split", "Split:1", "Split:2" are the scope-auto-generated names).
  for (const NodeDef& node : output.node()) {
    if (node.op() == "Concat") {
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "Split");
      EXPECT_EQ(node.input(1), "Split:1");
      EXPECT_EQ(node.input(2), "Split:2");
    }
  }

  // The optimized graph must compute the same values.
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1905
TEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
  // The two transposes cancel out; the control dependency that "outputs" has
  // on the second transpose must be rerouted to their source placeholder.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output feed =
      ops::Placeholder(scope, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
  Output flipped = ops::Transpose(scope, feed, ops::Const(scope, {1, 0}));
  Output restored = ops::Transpose(scope, flipped, ops::Const(scope, {1, 0}));
  Output outputs = ops::Identity(
      scope.WithOpName("outputs").WithControlDependencies(restored),
      ops::Const(scope.WithOpName("outputs_const"), 1.0f));

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 3}));
  item.feed = {{"Placeholder", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // The control input now points at the placeholder instead of a transpose.
  NodeMap node_map(&output);
  const NodeDef* outputs_node = node_map.GetNode("outputs");
  ASSERT_EQ(outputs_node->input_size(), 2);
  EXPECT_EQ(outputs_node->input(0), "outputs_const");
  EXPECT_EQ(outputs_node->input(1), "^Placeholder");

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
1940
TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
  // Applying perm = {1, 2, 3, 0} twice yields {2, 3, 0, 1}, not the identity,
  // so neither Transpose may be removed.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm = ops::Const(scope.WithOpName("perm"), {1, 2, 3, 0}, {4});
  Output transpose1 =
      ops::Transpose(scope.WithOpName("transpose1"), inputs, perm);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), transpose1, perm);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), transpose2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // All six original nodes must survive.
  EXPECT_EQ(output.node_size(), 6);
}
1964
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  // The two transposes cancel even though an Identity node sits between them;
  // in AGGRESSIVE mode the optimizer must look through that chain.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs_shape =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), inputs_shape, DT_FLOAT);
  Output perm1 = ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output perm2 = ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  Output transpose1 = ops::Transpose(
      scope.WithOpName("transpose1").WithControlDependencies(perm2), inputs,
      perm1);
  Output identity = ops::Identity(scope.WithOpName("id"), transpose1);
  Output transpose2 =
      ops::Transpose(scope.WithOpName("transpose2"), identity, perm2);
  Output id1 = ops::Identity(scope.WithOpName("id1"), transpose2);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  std::set<string> remaining;
  for (const NodeDef& node : output.node()) {
    remaining.insert(node.name());
    if (node.name() == "id") {
      // The inner Identity is now fed directly by the inputs.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "inputs");
    } else if (node.name() == "id1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id");
    }
  }
  EXPECT_EQ(remaining,
            std::set<string>({"id", "id1", "inputs_shape", "inputs"}));
}
2004
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
  // A scalar multiply feeding a Transpose feeding a Conv2D must be folded
  // into the convolution weights, regardless of the Mul operand order.
  for (bool swap_inputs : {false, true}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                     ops::Placeholder::Shape({1, 28, 28, 3}));
    Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
    Output scaled_inputs = ops::Multiply(scope.WithOpName("scaled_inputs"),
                                         swap_inputs ? scale : inputs,
                                         swap_inputs ? inputs : scale);
    Output perm_nhwc_to_nchw =
        ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
    Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                        scaled_inputs, perm_nhwc_to_nchw);
    Output weights = ops::Const(scope.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 4}));
    Output conv =
        ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights,
                    {1, 1, 1, 1}, "VALID", ops::Conv2D::DataFormat("NCHW"));
    Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

    GrapplerItem item;
    item.fetch = {"outputs"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFoldMultipleIntoConv(&optimizer);
    OptimizeTwiceAndPrune(&optimizer, &item, &output);

    NodeMap node_map(&output);

    // `conv` survives as a folded convolution whose weights are now scaled
    // by a Mul node.
    const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
    ASSERT_NE(folded_conv, nullptr);
    const NodeDef* folded_conv_weights =
        node_map.GetNode(folded_conv->input(1));
    ASSERT_NE(folded_conv_weights, nullptr);
    EXPECT_EQ(folded_conv_weights->op(), "Mul");

    // Its data input is a transpose of the unscaled `inputs`.
    const NodeDef* transpose =
        node_map.GetNode(NodeName(folded_conv->input(0)));
    ASSERT_NE(transpose, nullptr);
    ASSERT_EQ(transpose->input_size(), 2);
    EXPECT_EQ(transpose->input(0), "inputs");
  }
}
2054
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
  // `inputs_nchw` is a feed node and therefore preserved; the multiply must
  // not be folded across it into the convolution weights.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 3}));
  Output scale = ops::Const(scope.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(scope.WithOpName("scaled_inputs"), inputs, scale);
  Output perm_nhwc_to_nchw =
      ops::Const(scope.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
  Output inputs_nchw = ops::Transpose(scope.WithOpName("inputs_nchw"),
                                      scaled_inputs, perm_nhwc_to_nchw);
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv =
      ops::Conv2D(scope.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
                  "VALID", ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(scope.WithOpName("outputs"), conv);

  // All-zero feed value for the preserved transpose node.
  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
  inputs_nchw_tensor.flat<float>().setZero();

  GrapplerItem item;
  item.fetch = {"outputs"};
  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  // The preserved transpose must still consume the (unfolded) scaled inputs.
  NodeMap node_map(&output);
  const NodeDef* inputs_nchw_node_def =
      node_map.GetNode(inputs_nchw.node()->name());
  ASSERT_NE(inputs_nchw_node_def, nullptr);
  ASSERT_EQ(inputs_nchw_node_def->input_size(), 2);
  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
            scaled_inputs.node()->name());
}
2096
TEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
  // Conv3D(Mul(inputs, scale), weights) should be rewritten so the scalar
  // multiply applies to the (constant) weights instead of the activations:
  // Conv3D(inputs, Mul(weights, scale)).
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
  Output scaled_inputs =
      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
                            {1, 1, 1, 1, 1}, "VALID");
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));

  item.graph.Swap(&output);
  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));

  NodeMap node_map(&output);
  // `conv` is now a folded convolution on `inputs` and scaled weights.
  // Use EXPECT_EQ instead of CHECK_EQ: CHECK_EQ aborts the whole test binary
  // on failure rather than reporting a gtest test failure.
  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
  ASSERT_NE(folded_conv, nullptr);
  ASSERT_EQ(folded_conv->input_size(), 2);
  EXPECT_EQ(NodeName(folded_conv->input(0)), inputs.node()->name());
  const NodeDef* folded_conv_input_1 =
      node_map.GetNode(NodeName(folded_conv->input(1)));
  ASSERT_NE(folded_conv_input_1, nullptr);
  EXPECT_EQ(folded_conv_input_1->op(), "Mul");
}
2131
TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
  // This unit test exercises two optimizations, folding mul into conv, and
  // reordering cast and transpose.
  //
  //   Conv2D(Transpose(Mul(Cast(I), S)), W)
  //     =>
  //   Conv2D(Transpose(Cast(I)), W*S)
  //     =>
  //   Conv2D(Cast(Transpose(I)), W*S)
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  Output inputs =
      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
  Output cast = ops::Cast(s, inputs, DT_FLOAT);
  Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
  Output transpose =
      ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
  Output weights = ops::Const(s.WithOpName("weights"),
                              Input::Initializer(127.0f, {5, 5, 3, 16}));
  Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
                            ops::Conv2D::DataFormat("NCHW"));
  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;  // all optimization stages are on
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
  NodeMap node_map(&output);

  // Expected names for reordered cast and transpose.
  const string p = "ArithmeticOptimizer/ReorderCastLikeAndValuePreserving_";
  const string optimized_cast_name = absl::StrCat(p, "float_Cast");
  const string optimized_transpose_name = absl::StrCat(p, "uint8_Transpose");

  // Expected names for folded multiply and conv.
  const string optimized_weights =
      "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";

  const NodeDef* inputs_node = node_map.GetNode("Placeholder");
  const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
  const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");

  ASSERT_NE(inputs_node, nullptr);
  ASSERT_NE(transpose_node, nullptr);
  ASSERT_NE(cast_node, nullptr);
  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(conv_node, nullptr);

  // Final graph: Placeholder -> Transpose(uint8) -> Cast -> Conv2D with the
  // pre-scaled weights, plus supporting const nodes (7 nodes total).
  EXPECT_EQ(output.node_size(), 7);
  ASSERT_EQ(transpose_node->input_size(), 2);
  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
  ASSERT_EQ(cast_node->input_size(), 1);
  EXPECT_EQ(cast_node->input(0), transpose_node->name());
  ASSERT_EQ(conv_node->input_size(), 2);
  EXPECT_EQ(conv_node->input(0), cast_node->name());
  EXPECT_EQ(conv_node->input(1), weights_node->name());
}
2195
TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
  // This unit test exercises optimization of folding mul into conv for
  // multiple nodes in the graph.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/cpu:0");

  GrapplerItem item;
  Output conv[2];

  // Build two independent Mul -> Conv2D chains in the same scope; the second
  // iteration's nodes get auto-generated "_1"-suffixed names (e.g. Conv2D_1).
  for (int i = 0; i < 2; ++i) {
    Output inputs =
        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
    Output weights = ops::Const(s.WithOpName("weights"),
                                Input::Initializer(127.0f, {5, 5, 3, 16}));
    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
                          ops::Conv2D::DataFormat("NCHW"));
  }
  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);

  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFoldMultipleIntoConv(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);

  NodeMap node_map(&output);

  // Each convolution's weights must be replaced by an optimizer-generated
  // scaled-weights node; both rewrites must happen independently.
  using absl::StrCat;
  const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
  const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
  const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");

  const NodeDef* weights_node = node_map.GetNode(optimized_weights);
  const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
  const NodeDef* conv_node = node_map.GetNode("Conv2D");
  const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");

  ASSERT_NE(weights_node, nullptr);
  ASSERT_NE(weights_node_1, nullptr);
  ASSERT_NE(conv_node, nullptr);
  ASSERT_NE(conv_node_1, nullptr);

  ASSERT_EQ(conv_node->input_size(), 2);
  ASSERT_EQ(conv_node_1->input_size(), 2);
  EXPECT_EQ(conv_node->input(1), weights_node->name());
  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
2245
TEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
  // Two chained Bitcasts over same-width types collapse into a single one:
  // uint8 -> qint8 -> int8 becomes uint8 -> int8.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_UINT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope.WithOpName("bc1"), inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope.WithOpName("bc2"), bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_UINT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // A single Bitcast remains and it reads straight from the placeholder.
  EXPECT_EQ(output.node_size(), 3);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 1);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "bc2"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2279
TEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
  // int8 -> qint8 -> int8 round-trips to the source type, so both Bitcasts
  // disappear entirely and the consumer reads the placeholder directly.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output inputs = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                   ops::Placeholder::Shape({2, 3}));
  Output bc1 = ops::Bitcast(scope, inputs, DT_QINT8);
  Output bc2 = ops::Bitcast(scope, bc1, DT_INT8);
  Output outputs = ops::Identity(scope.WithOpName("outputs"), bc2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto feed_t = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", feed_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantBitcast(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // No Bitcast is left; "outputs" is wired directly to "inputs".
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Bitcast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2313
TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
  // A Cast whose source and destination types are both DT_INT8 is a no-op
  // and should be removed.
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output placeholder = ops::Placeholder(scope.WithOpName("inputs"), DT_INT8,
                                        ops::Placeholder::Shape({2, 3}));
  Output noop_cast = ops::Cast(scope, placeholder, DT_INT8);
  Output identity = ops::Identity(scope.WithOpName("outputs"), noop_cast);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto input_tensor = GenerateRandomTensor<DT_INT8>(TensorShape({2, 3}));
  item.feed = {{"inputs", input_tensor}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveRedundantCast(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // The Cast is gone and "inputs" feeds "outputs" directly.
  EXPECT_EQ(output.node_size(), 2);
  EXPECT_EQ(CountOpNodes(output, "Cast"), 0);
  EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));

  // The optimized graph must produce the same values as the original.
  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int8>(tensors[0], tensors_expected[0]);
}
2346
// AddOpsRewrite: a binary tree of Add ops whose inputs all have the same
// statically known shape collapses into a single AddN node. The rewritten
// node is expected under the "y" sub-scope, i.e. the scope of the tree's
// root op ("Add_abc").
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfIdenticalShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  tensorflow::Scope sx = s.NewSubScope("x");
  tensorflow::Scope sy = s.NewSubScope("y");

  // All three inputs share the identical static shape {2, 2}.
  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_bc = ops::Add(sx.WithOpName("Add_bc"), b, c);
  auto add_abc = ops::Add(sy.WithOpName("Add_abc"), a, add_bc);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //      +
  //     / \
  //    a   +    -->    AddN(a, b, c)
  //       / \
  //      b   c
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);

  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2410
// AddOpsRewrite: two independent Add trees feeding a Mul are each collapsed
// into their own AddN, and the Mul's inputs are re-wired to the new nodes.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Left tree: (a + b) + c, all shapes {2, 2}.
  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  // Right tree: (x + y) + z, all shapes {2, 2}.
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  // Mul separates the two trees so they must be rewritten independently.
  auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), mul);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //         *
  //        / \
  //       +   +                   *
  //      /|   |\                 / \
  //     + c   x +   -->  AddN(a, b, c)  AddN(x, y, z)
  //    / \     / \
  //   a   b   y   z
  EXPECT_EQ(output.node_size(), 10);

  NodeMap node_map(&output);

  // check left Add subtree replaced with AddN
  const NodeDef* collapsed_left =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_left, nullptr);

  EXPECT_EQ(collapsed_left->op(), "AddN");
  ASSERT_EQ(collapsed_left->input_size(), 3);
  EXPECT_EQ(collapsed_left->input(0), "a");
  EXPECT_EQ(collapsed_left->input(1), "b");
  EXPECT_EQ(collapsed_left->input(2), "c");

  // check right Add subtree replaced with AddN
  const NodeDef* collapsed_right =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz");
  ASSERT_NE(collapsed_right, nullptr);

  EXPECT_EQ(collapsed_right->op(), "AddN");
  ASSERT_EQ(collapsed_right->input_size(), 3);
  EXPECT_EQ(collapsed_right->input(0), "x");
  EXPECT_EQ(collapsed_right->input(1), "y");
  EXPECT_EQ(collapsed_right->input(2), "z");

  // check that Mul inputs re-wired to new Nodes
  const NodeDef* updated_mul = node_map.GetNode("Mul");
  ASSERT_NE(updated_mul, nullptr);

  EXPECT_EQ(updated_mul->op(), "Mul");
  ASSERT_EQ(updated_mul->input_size(), 2);
  EXPECT_EQ(updated_mul->input(0), collapsed_left->name());
  EXPECT_EQ(updated_mul->input(1), collapsed_right->name());

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2498
// AddOpsRewrite: an input that appears in more than one leaf of the Add tree
// ("b" below) must appear the same number of times in the resulting AddN, so
// the rewritten graph stays numerically equivalent.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputMultipleTimes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
  // "b" feeds both intermediate sums: (a + b) + (b + c).
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
  auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //        +
  //       / \
  //      +   +     -->  AddN(a, b, b, c)
  //     / \ / \                ^
  //    a   b   c        b added twice!
  EXPECT_EQ(output.node_size(), 5);

  NodeMap node_map(&output);

  // check Add tree replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_all");
  ASSERT_NE(collapsed_add, nullptr);

  // "b" must be listed twice among the AddN inputs.
  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 4);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "b");
  EXPECT_EQ(collapsed_add->input(3), "c");

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2555
// AddOpsRewrite: inputs whose shapes are unknown ({-1, 2}) but symbolically
// equal — they all derive element-wise from the same variable — are still
// eligible for the AddN rewrite.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddOpsOfSymbolicallyEqualShape) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // unknown input shape propagated symbolically through the graph
  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);

  // [a, b, c] have symbolically equal shapes (element-wise unary ops
  // preserve the shape of "input")
  auto a = ops::Sqrt(s.WithOpName("a"), input);
  auto b = ops::Square(s.WithOpName("b"), input);
  auto c = ops::Round(s.WithOpName("c"), input);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Feed a concrete {2, 2} tensor for the symbolic {-1, 2} shape.
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"input", x_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //      +
  //     / \
  //    +   c    -->    AddN(a, b, c)
  //   / \
  //  a   b
  EXPECT_EQ(output.node_size(), 6);

  NodeMap node_map(&output);

  // check add tree was replaced with AddN
  const NodeDef* collapsed_add =
      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc");
  ASSERT_NE(collapsed_add, nullptr);
  EXPECT_EQ(collapsed_add->op(), "AddN");
  ASSERT_EQ(collapsed_add->input_size(), 3);
  EXPECT_EQ(collapsed_add->input(0), "a");
  EXPECT_EQ(collapsed_add->input(1), "b");
  EXPECT_EQ(collapsed_add->input(2), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), collapsed_add->name());

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2619
// AddOpsRewrite: when the summed inputs have different shapes, same-shape
// inputs are grouped into AddN leaves first, and the leaves are then combined
// with broadcasting AddV2 nodes ordered from the smallest shape up, so each
// broadcast happens at most once per shape transition.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCast) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // First tree: shapes {32}, {32, 32}, {32, 32, 32}.
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32, 32, 32}, DT_FLOAT);
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  // Second tree: same shape progression as the first.
  auto x = ops::Variable(s.WithOpName("x"), {32}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {32, 32}, DT_FLOAT);
  auto z = ops::Variable(s.WithOpName("z"), {32, 32, 32}, DT_FLOAT);
  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);

  auto add_all = ops::Add(s.WithOpName("AddAll"), add_abc, add_xyz);
  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto z_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"x", x_t}, {"y", y_t}, {"z", z_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  // 1) [a, x], [b, y], [c, z] - aggregate same shapes first
  // 2) Build an aggregation tree minimizing cost of broadcast
  //
  //         +                              +
  //        / \                            / \
  //       +   +                          +   AddN(c, z)
  //      /|   |\                        / \
  //     + c   x +      -->     AddN(a, x)  AddN(b, y)
  //    / \     / \
  //   a   b   y   z
  EXPECT_EQ(output.node_size(), 12);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_AddAll";
  string outer_0_add_name =
      "ArithmeticOptimizer/AddOpsRewrite_Internal_0_AddAll";
  string inner_0_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_AddAll";
  string inner_1_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_1_AddAll";
  string inner_2_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_2_AddAll";

  // Add [a, x] first
  const NodeDef* add_ax_node = node_map.GetNode(inner_0_add_name);
  ASSERT_NE(add_ax_node, nullptr);
  EXPECT_EQ(add_ax_node->op(), "AddN");
  ASSERT_EQ(add_ax_node->input_size(), 2);
  EXPECT_EQ(add_ax_node->input(0), "a");
  EXPECT_EQ(add_ax_node->input(1), "x");

  // Then add [b, y]
  const NodeDef* add_by_node = node_map.GetNode(inner_1_add_name);
  ASSERT_NE(add_by_node, nullptr);
  EXPECT_EQ(add_by_node->op(), "AddN");
  // Fixed argument order (actual, expected) for consistency with the rest of
  // the file; the original had ASSERT_EQ(2, add_by_node->input_size()).
  ASSERT_EQ(add_by_node->input_size(), 2);
  EXPECT_EQ(add_by_node->input(0), "b");
  EXPECT_EQ(add_by_node->input(1), "y");

  // Then add [c, z]
  const NodeDef* add_cz_node = node_map.GetNode(inner_2_add_name);
  ASSERT_NE(add_cz_node, nullptr);
  EXPECT_EQ(add_cz_node->op(), "AddN");
  ASSERT_EQ(add_cz_node->input_size(), 2);
  EXPECT_EQ(add_cz_node->input(0), "c");
  EXPECT_EQ(add_cz_node->input(1), "z");

  // Then add results together starting from smaller shapes [a, x] + [b, y]
  const NodeDef* outer_0_node = node_map.GetNode(outer_0_add_name);
  ASSERT_NE(outer_0_node, nullptr);
  EXPECT_EQ(outer_0_node->op(), "AddV2");
  ASSERT_EQ(outer_0_node->input_size(), 2);
  EXPECT_EQ(outer_0_node->input(0), inner_0_add_name);
  EXPECT_EQ(outer_0_node->input(1), inner_1_add_name);

  // And finally top level Add node
  const NodeDef* outer_node = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_node, nullptr);
  EXPECT_EQ(outer_node->op(), "AddV2");
  ASSERT_EQ(outer_node->input_size(), 2);
  EXPECT_EQ(outer_node->input(0), outer_0_add_name);
  EXPECT_EQ(outer_node->input(1), inner_2_add_name);

  // And outputs reading new top level Add node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2731
// AddOpsRewrite: with symbolic (partially unknown) shapes, inputs that are
// symbolically small ({?, 1, 1}) are aggregated into an AddN leaf first, and
// the single symbolically large input ({?, 32, 32}) is added last so the
// broadcast happens only once.
TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMinimizeBCastWithSymbolicShapes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // We have a small input with one unknown dimension
  auto small = ops::Variable(s.WithOpName("small"), {-1, 1, 1}, DT_DOUBLE);

  // And second input which is larger, but has the same unknown dimension
  // device spec prevents this node from rewriting
  auto d = "/device:CPU:0";
  auto v = ops::Variable(s.WithOpName("v"), {1, 32, 32}, DT_DOUBLE);
  auto large = ops::Add(s.WithOpName("large").WithDevice(d), small, v);

  // [a, c] have {?, 1, 1} shape, [b] has {?, 32, 32}
  auto a = ops::Sqrt(s.WithOpName("a"), small);
  auto b = ops::Square(s.WithOpName("b"), large);
  auto c = ops::Round(s.WithOpName("c"), small);

  // [add_ab, add_abc] shape must be inferred from inputs
  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);

  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Bind the unknown leading dimension to 8 for evaluation.
  auto s_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({8, 1, 1}));
  auto v_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({1, 32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {{"small", s_t}, {"v", v_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyAddToAddNCombining(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: it's much cheaper to add small
  // tensors, and do the broadcast just once
  //
  //      +                  +
  //     / \                / \
  //    +   c      -->     +   b
  //   / \                / \
  //  a   b              a   c
  EXPECT_EQ(output.node_size(), 9);
  NodeMap node_map(&output);

  // expected names of outer and inner nodes
  string outer_add_name = "ArithmeticOptimizer/AddOpsRewrite_Add_abc";
  string inner_add_name = "ArithmeticOptimizer/AddOpsRewrite_Leaf_0_Add_abc";

  // outer Add node: broadcasting add of the small-shape sum with "b"
  const NodeDef* outer_add = node_map.GetNode(outer_add_name);
  ASSERT_NE(outer_add, nullptr);
  EXPECT_EQ(outer_add->op(), "AddV2");
  ASSERT_EQ(outer_add->input_size(), 2);
  EXPECT_EQ(outer_add->input(0), inner_add_name);
  EXPECT_EQ(outer_add->input(1), "b");

  // inner AddN node: aggregates the two {?, 1, 1} inputs
  const NodeDef* inner_add = node_map.GetNode(inner_add_name);
  ASSERT_NE(inner_add, nullptr);
  ASSERT_EQ(inner_add->input_size(), 2);
  EXPECT_EQ(inner_add->input(0), "a");
  EXPECT_EQ(inner_add->input(1), "c");

  // check output was re-wired to new node
  const NodeDef* updated_outputs = node_map.GetNode("outputs");
  ASSERT_NE(updated_outputs, nullptr);
  ASSERT_EQ(updated_outputs->input_size(), 1);
  EXPECT_EQ(updated_outputs->input(0), outer_add_name);

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
2810
// RemoveNegation: Add/Sub nodes with negated operands are rewritten into
// equivalent Add/Sub nodes that read the un-negated inputs directly, e.g.
// (-x) + y => y - x. Control dependencies on the removed Neg must survive.
TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
  Output neg_x = ops::Neg(s.WithOpName("Neg_x"), x);
  Output neg_y = ops::Neg(s.WithOpName("Neg_y"), y);
  // All combinations of plain/negated operands under Add and Sub.
  Output add_x_y = ops::Add(s.WithOpName("Add_x_y"), x, y);
  Output add_negx_y = ops::Add(s.WithOpName("Add_negx_y"), neg_x, y);
  Output add_x_negy = ops::Add(s.WithOpName("Add_x_negy"), x, neg_y);
  Output add_negx_negy = ops::Add(s.WithOpName("Add_negx_negy"), neg_x, neg_y);
  Output sub_x_y = ops::Sub(s.WithOpName("Sub_x_y"), x, y);
  Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
  Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
  Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
  // A Neg carrying a control dependency: the dependency must be preserved
  // on the rewritten consumer.
  Output neg_x_with_dep = ops::Neg(
      s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
  Output add_negx_with_dep_y =
      ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
  // AddN keeps every sub-expression alive so none of them gets pruned.
  auto add_all =
      ops::AddN(s.WithOpName("add_all"),
                {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
                 sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});

  GrapplerItem item;
  item.fetch = {"add_all"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  auto y_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
  std::vector<std::pair<string, Tensor>> feed = {{"x", x_t}, {"y", y_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveNegation(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  EXPECT_EQ(output.node_size(), item.graph.node_size());
  int found = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "Add_negx_y") {
      // (-x) + y  =>  y - x
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_x_negy") {
      // x + (-y)  =>  x - y
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Add_negx_negy") {
      // (-x) + (-y)  =>  Neg_x - y  (only the second Neg is elided)
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "Neg_x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_x_negy") {
      // x - (-y)  =>  x + y
      ++found;
      EXPECT_EQ(node.op(), "AddV2");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    } else if (node.name() == "Sub_negx_negy") {
      // (-x) - (-y)  =>  y - x
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
    } else if (node.name() == "Add_negx_with_dep_y") {
      // Rewritten like Add_negx_y, with the control dependency of the
      // removed Neg_x_with_dep carried over as a third (control) input.
      ++found;
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "y");
      EXPECT_EQ(node.input(1), "x");
      EXPECT_EQ(node.input(2), "^Add_x_y");
    }
  }
  // Exactly the six rewritten nodes above were observed.
  EXPECT_EQ(found, 6);

  // The optimized graph must be numerically equivalent to the original.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
2898
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
  // x / sqrt(y) should be rewritten as x * rsqrt(y).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_node = ops::Sqrt(scope.WithOpName("sqrt_y"), radicand);
  Output quotient = ops::Div(scope.WithOpName("output"), numerator, sqrt_node);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      // The Div became a Mul over the same inputs...
      EXPECT_EQ(node.op(), "Mul");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      // ...and the Sqrt was turned into Rsqrt in place.
      EXPECT_EQ(node.op(), "Rsqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
2935
// Verifies that x / sqrt(y) is NOT rewritten to x * rsqrt(y) when the Sqrt
// node is itself fetched: converting it to Rsqrt in place would change the
// value observed at the fetch.
TEST_F(ArithmeticOptimizerTest, DoNotConvertSqrtDivToRsqrtMulDivisorFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output output0 = ops::Sqrt(s.WithOpName("output0"), floats);
  Output const1 = ops::Const(s.WithOpName("const1"), 1.0f, {3});
  Output mul1 = ops::Multiply(s.WithOpName("mul1"), const1, 0.5f);
  Output grad = ops::Div(s.WithOpName("grad"), mul1, output0);

  GrapplerItem item;
  // "output0" (the Sqrt) is a fetch node, which must block the rewrite.
  item.fetch = {"grad", "output0"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Use size_t for the index to avoid the signed/unsigned comparison with
  // std::vector::size() in the original `int i` loop.
  for (size_t i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  // The graph must be structurally untouched: same node count, and both the
  // Div and the Sqrt keep their original ops and inputs.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "grad") {
      EXPECT_EQ(node.op(), "Div");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "mul1");
      EXPECT_EQ(node.input(1), "output0");
    } else if (node.name() == "output0") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "floats");
    }
  }
}
2977
TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMulExcludeFloorDiv) {
  // FloorDiv is not a true division: x // sqrt(y) must NOT be rewritten
  // to a multiplication by rsqrt(y).
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto numerator = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto radicand = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output sqrt_node = ops::Sqrt(scope.WithOpName("sqrt_y"), radicand);
  Output quotient =
      ops::FloorDiv(scope.WithOpName("output"), numerator, sqrt_node);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  ArithmeticOptimizer optimizer;
  EnableOnlySqrtDivToRsqrtMul(&optimizer);
  GraphDef output;
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      // FloorDiv is left untouched...
      EXPECT_EQ(node.op(), "FloorDiv");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "sqrt_y");
    } else if (node.name() == "sqrt_y") {
      // ...and so is the Sqrt.
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "y");
    }
  }
}
3014
// FuseSquaredDiff: Square(Sub(x, y)) is fused into SquaredDifference(x, y),
// with the Square node becoming an Identity. For complex inputs the rewrite
// is skipped and the graph is left unchanged (presumably because
// SquaredDifference does not apply to complex — TODO confirm against the
// optimizer implementation).
TEST_F(ArithmeticOptimizerTest, FuseSquaredDiff) {
  for (bool is_complex : {false, true}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
    Output y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
    Output complex_x = ops::Complex(s.WithOpName("complex_x"), x, x);
    Output complex_y = ops::Complex(s.WithOpName("complex_y"), y, y);
    // Either Sub over the real constants or over their complex versions,
    // depending on the tested case.
    Output sub_x_y =
        is_complex ? ops::Sub(s.WithOpName("sub_x_y"), complex_x, complex_y)
                   : ops::Sub(s.WithOpName("sub_x_y"), x, y);
    Output square_sub_x_y = ops::Square(s.WithOpName("output"), sub_x_y);

    GrapplerItem item;
    item.fetch = {"output"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyFuseSquaredDiff(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    const auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);

    if (is_complex) {
      test::ExpectTensorNear<std::complex<float>>(tensors[0],
                                                  tensors_expected[0], 1e-6);
      // Complex case: no rewrite, so the node count is unchanged.
      EXPECT_EQ(output.node_size(), item.graph.node_size());
    } else {
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
      // The two unused Complex nodes should get pruned.
      EXPECT_EQ(output.node_size(), item.graph.node_size() - 2);
    }
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      if (node.name() == "output") {
        // Real case: Square became Identity over the fused op's result.
        EXPECT_EQ(node.op(), is_complex ? "Square" : "Identity");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sub_x_y");
      } else if (node.name() == "sub_x_y") {
        // Real case: Sub became SquaredDifference over the same inputs.
        EXPECT_EQ(node.op(), is_complex ? "Sub" : "SquaredDifference");
        ASSERT_EQ(node.input_size(), 2);
        EXPECT_EQ(node.input(0), is_complex ? "complex_x" : "x");
        EXPECT_EQ(node.input(1), is_complex ? "complex_y" : "y");
      }
    }
  }
}
3064
// The FuseSquaredDiff stage must leave the graph alone when the intermediate
// Sub node is itself fetched, since rewriting it would change a fetch output.
TEST_F(ArithmeticOptimizerTest, DoNotFuseSquaredDiffFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto lhs = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto rhs = ops::Const(scope.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  Output diff = ops::Sub(scope.WithOpName("sub_x_y"), lhs, rhs);
  Output squared = ops::Square(scope.WithOpName("output"), diff);

  GrapplerItem item;
  item.fetch = {"output", "sub_x_y"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyFuseSquaredDiff(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Values must match the unoptimized graph for both fetch outputs.
  for (int i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
  // No nodes were added or removed, and both ops keep their original form.
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "Square");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sub_x_y");
    } else if (node.name() == "sub_x_y") {
      EXPECT_EQ(node.op(), "Sub");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      EXPECT_EQ(node.input(1), "y");
    }
  }
}
3104
// Log(Softmax(x)) should be collapsed into a single LogSoftmax node.
TEST_F(ArithmeticOptimizerTest, ConvertLogSoftmax) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto input = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sm = ops::Softmax(scope.WithOpName("softmax"), input);
  Output log_of_sm = ops::Log(scope.WithOpName("output"), sm);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  // The stand-alone Softmax is gone; "output" is now a LogSoftmax on "x".
  EXPECT_EQ(output.node_size(), item.graph.node_size() - 1);
  for (const NodeDef& node : output.node()) {
    if (node.name() == "output") {
      EXPECT_EQ(node.op(), "LogSoftmax");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
    }
  }
}
3135
// The LogSoftmax rewrite must be skipped when the Softmax output is itself
// fetched, since replacing it would change the value of that fetch node.
TEST_F(ArithmeticOptimizerTest, DoNotConvertLogSoftmaxArgFetchNode) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output floats = ops::Const(s.WithOpName("floats"),
                             {0.7423212f, 0.19757693f, 0.53124744f}, {1, 3});
  Output softmax = ops::Softmax(s.WithOpName("softmax"), floats);
  Output final_output = ops::Log(s.WithOpName("final_output"), softmax);

  GrapplerItem item;
  // Fetching "softmax" directly is what must block the rewrite.
  item.fetch = {"softmax", "final_output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyLogSoftmax(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);
  const auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, output, __LINE__);

  for (int i = 0; i < tensors.size(); i++) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3165
// Exercises the ConvertPow stage. Pow(x, c) with a constant exponent c is
// specialized: 2 -> Square, 3 -> Mul(x, Square(x)), 1 -> Identity,
// 0.5 -> Sqrt, 0 -> Const(1) with a control dependency on x, -0.5 -> Rsqrt,
// -1 -> Reciprocal. A general exponent, and cases where removing the Pow
// would change the broadcast output shape, must keep the original Pow.
TEST_F(ArithmeticOptimizerTest, ConvertPow) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
  auto y3 = ops::Const(s.WithOpName("y3"), {3.0f, 3.0f}, {1, 2});
  auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
  auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
  auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
  auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
  auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
  // z is a scalar; the {1, 3} "ones"/"zeros" exponents broadcast it, so the
  // rewrite must not apply (it would change the output shape).
  auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
  auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
  auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
  Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
  // out3 is pinned to a device so we can check that the helper Square node
  // created by the x^3 rewrite inherits it.
  Output out3 =
      ops::Pow(s.WithOpName("out3").WithDevice("/device:CPU:0"), x, y3);
  Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
  Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
  Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
  Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
  Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
  Output out = ops::Pow(s.WithOpName("out"), x, y);
  Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
  Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);

  GrapplerItem item;
  item.fetch = {"out2", "out3", "out1", "out.5", "out0",
                "out_.5", "out_1", "out", "out_bcast1", "out_bcast2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 10);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyConvertPow(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 10);

  // Numerical results must be unchanged by the rewrites.
  for (int i = 0; i < tensors.size(); ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // Build the expected post-rewrite graph and compare structurally.
  GraphDef want;
  AddNode("x", "Const", {}, {}, &want);
  AddNode("y", "Const", {}, {}, &want);
  AddNode("z", "Const", {}, {}, &want);
  AddNode("ones", "Const", {}, {}, &want);
  AddNode("zeros", "Const", {}, {}, &want);
  AddNode("out2", "Square", {"x"}, {}, &want);
  AddNode("ArithmeticOptimizer/ConvertPow__inner_out3", "Square", {"x"}, {},
          &want)
      ->set_device("/device:CPU:0");
  AddNode("out3", "Mul", {"x", "ArithmeticOptimizer/ConvertPow__inner_out3"},
          {}, &want)
      ->set_device("/device:CPU:0");
  AddNode("out1", "Identity", {"x"}, {}, &want);
  AddNode("out.5", "Sqrt", {"x"}, {}, &want);
  AddNode("out0", "Const", {AsControlDependency("x")}, {}, &want);
  AddNode("out_.5", "Rsqrt", {"x"}, {}, &want);
  AddNode("out_1", "Reciprocal", {"x"}, {}, &want);
  AddNode("out", "Pow", {"x", "y"}, {}, &want);
  // The broadcasting cases keep their Pow so the output shape is preserved.
  AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
  AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);

  CompareGraphs(want, got);
}
3235
// Log(Add(x, y)) is rewritten to Log1p when one addend is the constant 1:
// out1 = Log(x1 + x2) with x1 == 1 becomes Log1p(x2), preserving a12's
// control dependency on x3. out2 = Log(x2 + x3) has no unit addend and must
// remain a plain Log.
TEST_F(ArithmeticOptimizerTest, Log1p) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
  auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
  Output out1 = ops::Log(s.WithOpName("out1"), a12);
  Output out2 = ops::Log(s.WithOpName("out2"), a23);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyLog1p(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // After pruning, x1 and a12 are gone and out1 reads x2 directly.
  GraphDef want;
  AddNode("x2", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
  AddNode("out1", "Log1p", {"x2", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Log", {"a23"}, {}, &want);

  CompareGraphs(want, got);
}
3274
// Sub(Exp(x), y) is rewritten to Expm1(x) when y is the constant 1:
// out1 = exp1 - x2 with x2 == 1 becomes Expm1(x1), carrying over exp1's
// control dependency on x3. out2 = exp1 - x3 subtracts a non-unit value and
// stays a plain Sub, which also keeps the exp1 node alive.
TEST_F(ArithmeticOptimizerTest, Expm1) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
  auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
  auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
  auto exp1 = ops::Exp(s.WithOpName("exp1").WithControlDependencies(x3), x1);
  Output out1 = ops::Sub(s.WithOpName("out1"), exp1, x2);
  Output out2 = ops::Sub(s.WithOpName("out2"), exp1, x3);

  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 2);

  GraphDef got;
  ArithmeticOptimizer optimizer;
  EnableOnlyExpm1(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &got);
  auto tensors = EvaluateNodes(got, item.fetch);
  ASSERT_EQ(tensors.size(), 2);

  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }

  // x2 is pruned; out1 is now Expm1(x1) and out2 still consumes exp1.
  GraphDef want;
  AddNode("x1", "Const", {}, {}, &want);
  AddNode("x3", "Const", {}, {}, &want);
  AddNode("exp1", "Exp", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out1", "Expm1", {"x1", AsControlDependency("x3")}, {}, &want);
  AddNode("out2", "Sub", {"exp1", "x3"}, {}, &want);

  CompareGraphs(want, got);
}
3312
// Mul(Mul(vec_a, mat_b), vec_c): the optimizer should pair the two vectors
// first so broadcasting to the matrix shape happens only once.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(scope.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(scope.WithOpName("b"), {32, 32}, DT_FLOAT);
  auto c = ops::Variable(scope.WithOpName("c"), {32}, DT_FLOAT);

  auto mul1 = ops::Mul(scope.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(scope.WithOpName("mul2"), mul1, c);

  auto outputs = ops::Identity(scope.WithOpName("outputs"), mul2);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensor_a = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto tensor_b = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  auto tensor_c = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", tensor_a}, {"b", tensor_b}, {"c", tensor_c}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // Expected rewrite: the matrix operand b is swapped out of the inner Mul.
  //
  //      *             *
  //     / \           / \
  //    *   c  -->    *   b
  //   / \           / \
  //  a   b         a   c
  NodeMap node_map(&output);

  const NodeDef* inner = node_map.GetNode("mul1");
  ASSERT_NE(inner, nullptr);
  ASSERT_EQ(inner->input_size(), 2);
  EXPECT_EQ(inner->input(0), "a");
  EXPECT_EQ(inner->input(1), "c");

  const NodeDef* outer = node_map.GetNode("mul2");
  ASSERT_NE(outer, nullptr);
  ASSERT_EQ(outer->input_size(), 2);
  EXPECT_EQ(outer->input(0), "mul1");
  EXPECT_EQ(outer->input(1), "b");

  // The rewritten graph must still compute the same value.
  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3368
// A left-leaning chain of four Muls over one {32, 32} matrix (b) and four
// vectors should be rebalanced into a tree where the vectors are combined
// first and the matrix — the largest shape — is multiplied in at the root.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_DOUBLE);
  auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_DOUBLE);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_DOUBLE);
  auto d = ops::Variable(s.WithOpName("d"), {32}, DT_DOUBLE);
  auto e = ops::Variable(s.WithOpName("e"), {32}, DT_DOUBLE);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
  auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32, 32}));
  auto c_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  auto e_t = GenerateRandomTensor<DT_DOUBLE>(TensorShape({32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"d", d_t}, {"e", e_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur: Graph is "flattened" and
  // largest shape pushed to the top.
  //
  //          *
  //        /   \
  //       *     e                *
  //      /  \                  /   \
  //     *    d               *       b
  //    / \                 /  \
  //   *   c      -->     *      *
  //  / \                / \    / \
  // a   b              a   c  d   e
  NodeMap node_map(&output);

  // Existing node names are reused for the rebalanced tree, so each node is
  // checked by the inputs it ends up with rather than by what it used to be.
  const NodeDef* mul1_node = node_map.GetNode("mul1");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "c");

  const NodeDef* mul2_node = node_map.GetNode("mul2");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "d");
  EXPECT_EQ(mul2_node->input(1), "e");

  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "mul1");
  EXPECT_EQ(mul3_node->input(1), "mul2");

  const NodeDef* mul4_node = node_map.GetNode("mul4");
  ASSERT_NE(mul4_node, nullptr);
  ASSERT_EQ(mul4_node->input_size(), 2);
  EXPECT_EQ(mul4_node->input(0), "mul3");
  EXPECT_EQ(mul4_node->input(1), "b");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
3447
// When the large operand (matrix D) sits in a leaf Mul, the optimizer must
// rebuild the tree so that the vector-shaped operands are combined first and
// D is multiplied in only at the root.
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // [a, b, c] - scalars, [d] - matrix
  auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
  auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
  auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
  auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);

  auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
  auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
  auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);

  auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);

  GrapplerItem item;
  item.fetch = {"outputs"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto b_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto c_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32}));
  auto d_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({32, 32}));
  std::vector<std::pair<string, Tensor>> feed = {
      {"a", a_t}, {"b", b_t}, {"c", c_t}, {"D", d_t}};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, feed);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyMinimizeBroadcasts(&optimizer);

  OptimizeAndPrune(&optimizer, &item, &output);

  // We expect the following rewrite(s) to occur:
  //
  //                              *
  //                            /   \
  //       *       *       D     *
  //      / \     / \           / \
  //     *   *   ->           *   c
  //    / \ / \              / \
  //   a  b c  D            a   b
  NodeMap node_map(&output);

  // Note: the locals below are named for their position in the expected tree,
  // not for the node name they look up — after the rewrite, node "mul2" holds
  // a*b and node "mul1" holds (a*b)*c, because node names are reused.
  const NodeDef* mul1_node = node_map.GetNode("mul2");
  ASSERT_NE(mul1_node, nullptr);
  ASSERT_EQ(mul1_node->input_size(), 2);
  EXPECT_EQ(mul1_node->input(0), "a");
  EXPECT_EQ(mul1_node->input(1), "b");

  const NodeDef* mul2_node = node_map.GetNode("mul1");
  ASSERT_NE(mul2_node, nullptr);
  ASSERT_EQ(mul2_node->input_size(), 2);
  EXPECT_EQ(mul2_node->input(0), "mul2");
  EXPECT_EQ(mul2_node->input(1), "c");

  const NodeDef* mul3_node = node_map.GetNode("mul3");
  ASSERT_NE(mul3_node, nullptr);
  ASSERT_EQ(mul3_node->input_size(), 2);
  EXPECT_EQ(mul3_node->input(0), "D");
  EXPECT_EQ(mul3_node->input(1), "mul1");

  auto tensors = EvaluateNodes(output, item.fetch, feed);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
3515
// Hoisting the Relus above the Concat would change the computation (the Relu
// must stay attached to each BiasAdd branch), so the optimizer must keep both
// Relu nodes in place.
TEST_F(ArithmeticOptimizerTest, DoNotHoistReluFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output weights1 = ops::Const(s.WithOpName("weights1"),
                               Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output weights2 = ops::Const(s.WithOpName("weights2"),
                               Input::Initializer(2.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(s.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output axis = ops::Const(s.WithOpName("axis"), 3, {});
  Output input = ops::Const(s.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // Two parallel Conv2D -> BiasAdd -> Relu branches feeding a Concat.
  Output branch1 =
      ops::Conv2D(s.WithOpName("conv1"), input, weights1, {1, 1, 1, 1}, "SAME");
  branch1 = ops::BiasAdd(s.WithOpName("biasadd1"), branch1, biases);
  branch1 = ops::Relu(s.WithOpName("relu1"), branch1);
  Output branch2 =
      ops::Conv2D(s.WithOpName("conv2"), input, weights2, {1, 1, 1, 1}, "SAME");
  branch2 = ops::BiasAdd(s.WithOpName("biasadd2"), branch2, biases);
  branch2 = ops::Relu(s.WithOpName("relu2"), branch2);
  Output concat = ops::Concat(s.WithOpName("concat"), {branch1, branch2}, axis);
  Output output = ops::Identity(s.WithOpName("output"), concat);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  // Guard the indexed access below, as the sibling tests do; without this a
  // failed evaluation would read out of bounds instead of failing cleanly.
  ASSERT_EQ(tensors_expected.size(), item.fetch.size());

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  OptimizeAndPrune(&optimizer, &item, &new_graph);

  // Verify that the two Relus are not hoisted.
  EXPECT_EQ(CountOpNodes(new_graph, "Relu"), 2);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3556
// Identical element-wise unary chains feeding every input of a Concat should
// be hoisted to a single chain applied after the Concat.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(s.WithOpName("b"), 1.0f, {32});
  Output c = ops::Const(s.WithOpName("c"), 42.0f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // ctrl1..ctrl3 only exist to attach control dependencies to chain nodes.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //   Concat({Exp(sin_a), Exp(b), Exp(c)})
  // into
  //   Exp(Concat({sin_a, b, c})).
  // (sin_a is only the input to the first chain, not part of it.)
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
  Output exp_a =
      ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
  Output exp_c =
      ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
  Output concat =
      ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
  Output id = ops::Identity(s.WithOpName("id"), concat);

  // Test case with chains of length 2.
  // Rewrites
  //   Concat({Cos(Exp(sin_a)), Cos(Exp(b)), Cos(Exp(c))})
  // into
  //   Cos(Exp(Concat({sin_a, b, c}))).
  Output exp_a2 =
      ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
  Output exp_c2 =
      ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
  Output concat2 = ops::Concat(s.WithOpName("concat2"),
                               {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
  Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
  GrapplerItem item;
  item.fetch = {"id", "id2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  // Each branch below bumps `found`; all 7 rewired nodes must be observed.
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "concat") {
      // The Concat now reads the raw chain inputs directly.
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a") {
      // exp_a survives as the single hoisted Exp applied to the Concat.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat");
      found++;
    }
    if (node.name() == "id") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a");
      found++;
    }

    if (node.name() == "concat2") {
      ASSERT_EQ(node.input_size(), 4);
      EXPECT_EQ(node.input(0), "sin_a");
      EXPECT_EQ(node.input(1), "b");
      EXPECT_EQ(node.input(2), "c");
      EXPECT_EQ(node.input(3), "axis");
      found++;
    }
    if (node.name() == "exp_a2") {
      // The length-2 chain is hoisted as Cos(Exp(concat2)).
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "concat2");
      found++;
    }
    if (node.name() == "cos_exp_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_a2");
      found++;
    }
    if (node.name() == "id2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "cos_exp_a2");
      found++;
    }
  }
  EXPECT_EQ(found, 7);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3665
// Element-wise unary chains shared by all outputs of a Split/SplitV should be
// hoisted onto the Split's input, so the op runs once instead of per-output.
TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(s.WithOpName("x"), 3.1415f, {32});
  Output axis = ops::Const(s.WithOpName("axis"), 0, {});
  // ctrl1..ctrl3 only exist to attach control dependencies to chain nodes.
  Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
  Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
  Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
  // Test case with chains of length 1.
  // Rewrites
  //     [Sin(y) for y in Split(x)]
  // into
  //     [y for y in Split(Sin(x))].
  // Only the Sin common to both outputs is hoisted; the extra Exp on the
  // second output stays behind, reading split1:1.
  ops::Split split1(s.WithOpName("split1"), axis, x, 2);
  Output sin_a =
      ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl1), split1[0]);
  Output id_a = ops::Identity(s.WithOpName("id_a"), sin_a);
  Output sin_b = ops::Sin(s.WithOpName("sin_b"), split1[1]);
  Output exp_b = ops::Exp(s.WithOpName("exp_b"), sin_b);
  Output id_b = ops::Identity(s.WithOpName("id_b"), exp_b);

  // Test case with SplitV and chains of length 2.
  // Rewrites
  //     [Cos(Exp(y)) for y in Split(x)]
  // into
  //     [y for y in Split(Cos(Exp(x)))].
  Output size_splits2 = ops::Const(s.WithOpName("size_splits2"), {20, 12}, {2});
  ops::SplitV split2(s.WithOpName("split2"), x, size_splits2, axis, 2);
  Output exp_a2 = ops::Exp(
      s.WithOpName("exp_a2").WithControlDependencies(ctrl1), split2[0]);
  Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), split2[1]);
  Output cos_exp_a2 = ops::Cos(
      s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl2), exp_a2);
  Output cos_exp_b2 = ops::Cos(
      s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
  Output id_a2 = ops::Identity(s.WithOpName("id_a2"), cos_exp_a2);
  Output id_b2 = ops::Identity(s.WithOpName("id_b2"), cos_exp_b2);

  GrapplerItem item;
  item.fetch = {"id_a", "id_b", "id_a2", "id_b2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyHoistCWiseUnaryChains(&optimizer);
  OptimizeTwiceAndPrune(&optimizer, &item, &output);

  int found = 0;
  for (const NodeDef& node : output.node()) {
    // The following 6 nodes should be pruned.
    EXPECT_NE(node.name(), "sin_a");
    EXPECT_NE(node.name(), "sin_b");
    EXPECT_NE(node.name(), "exp_a2");
    EXPECT_NE(node.name(), "exp_b2");
    EXPECT_NE(node.name(), "cos_exp_a2");
    EXPECT_NE(node.name(), "cos_exp_b2");

    if (node.name() == "split1") {
      // Split now consumes the hoisted Sin instead of x directly.
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "axis");
      EXPECT_EQ(node.input(1), "ArithmeticOptimizer/_sin_a_split1");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
      EXPECT_EQ(node.op(), "Sin");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "id_a") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1");
      found++;
    }
    if (node.name() == "exp_b") {
      // The non-common Exp stays on the second Split output.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split1:1");
      found++;
    }
    if (node.name() == "id_b") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "exp_b");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Exp");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      found++;
    }
    if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
      EXPECT_EQ(node.op(), "Cos");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_exp_a2_split2");
      found++;
    }
    if (node.name() == "split2") {
      // The whole Cos(Exp(.)) chain was hoisted onto SplitV's input.
      ASSERT_EQ(node.input_size(), 3);
      EXPECT_EQ(node.input(0), "ArithmeticOptimizer/_cos_exp_a2_split2");
      EXPECT_EQ(node.input(1), "size_splits2");
      EXPECT_EQ(node.input(2), "axis");
      found++;
    }
    if (node.name() == "id_a2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2");
      found++;
    }
    if (node.name() == "id_b2") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "split2:1");
      found++;
    }
  }
  EXPECT_EQ(found, 10);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3790
// Snapshot(Snapshot(x)) and Identity(Identity(x)) collapse to a single
// application: applying an idempotent op twice is equivalent to applying it
// once, so consumers are rewired past the redundant inner node.
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
  Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
  Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
  Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
  Output id1 = ops::Identity(s.WithOpName("id1"), a);
  Output id2 = ops::Identity(s.WithOpName("id2"), id1);
  Output out2 = ops::Identity(s.WithOpName("out2"), id2);
  GrapplerItem item;
  item.fetch = {"out1", "out2"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveIdempotent(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // OptimizeTwice does not prune, so all 7 original nodes survive; only the
  // consumers' inputs are rewired past the redundant sn2/id2 nodes.
  // (Argument order fixed to the file's (actual, expected) convention.)
  EXPECT_EQ(output.node_size(), 7);
  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "out1") {
      // Rewired past sn2 to the single remaining Snapshot.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "sn1");
      found++;
    } else if (node.name() == "out2") {
      // Rewired past id2 to the single remaining inner Identity.
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "id1");
      found++;
    } else if (node.name() == "sn1") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
      found++;
    }
  }
  EXPECT_EQ(found, 3);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
  }
}
3837
// RemoveLogicalNot should fold LogicalNot into its comparison input by
// replacing each comparison op with its negation (Equal <-> NotEqual,
// Less <-> GreaterEqual, LessEqual <-> Greater) and rewiring consumers
// of the LogicalNot directly to the comparison node.
TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(scope.WithOpName("a"), 3.14f, {32});
  Output b = ops::Const(scope.WithOpName("b"), -3.14f, {32});
  Output eq = ops::Equal(scope.WithOpName("eq"), a, b);
  Output neq = ops::NotEqual(scope.WithOpName("neq"), a, b);
  Output lt = ops::Less(scope.WithOpName("lt"), a, b);
  Output le = ops::LessEqual(scope.WithOpName("le"), a, b);
  Output gt = ops::Greater(scope.WithOpName("gt"), a, b);
  Output ge = ops::GreaterEqual(scope.WithOpName("ge"), a, b);
  // not_eq is reserved
  Output not_eq1 = ops::LogicalNot(scope.WithOpName("not_eq1"), eq);
  Output not_neq = ops::LogicalNot(scope.WithOpName("not_neq"), neq);
  Output not_lt = ops::LogicalNot(scope.WithOpName("not_lt"), lt);
  Output not_le = ops::LogicalNot(scope.WithOpName("not_le"), le);
  Output not_gt = ops::LogicalNot(scope.WithOpName("not_gt"), gt);
  Output not_ge = ops::LogicalNot(scope.WithOpName("not_ge"), ge);
  Output id_not_eq = ops::Identity(scope.WithOpName("id_not_eq"), not_eq1);
  Output id_not_neq = ops::Identity(scope.WithOpName("id_not_neq"), not_neq);
  Output id_not_lt = ops::Identity(scope.WithOpName("id_not_lt"), not_lt);
  Output id_not_le = ops::Identity(scope.WithOpName("id_not_le"), not_le);
  Output id_not_gt = ops::Identity(scope.WithOpName("id_not_gt"), not_gt);
  Output id_not_ge = ops::Identity(scope.WithOpName("id_not_ge"), not_ge);

  GrapplerItem item;
  item.fetch = {"id_not_eq", "id_not_neq", "id_not_lt",
                "id_not_le", "id_not_gt", "id_not_ge"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveLogicalNot(&optimizer);
  OptimizeTwice(&optimizer, &item, &output);

  // Table-driven verification: each Identity must now read the comparison
  // node directly, and each comparison must carry the negated op type.
  struct NameAndValue {
    const char* name;
    const char* value;
  };
  constexpr NameAndValue kExpectedInputs[] = {
      {"id_not_eq", "eq"}, {"id_not_neq", "neq"}, {"id_not_lt", "lt"},
      {"id_not_le", "le"}, {"id_not_gt", "gt"},   {"id_not_ge", "ge"}};
  constexpr NameAndValue kExpectedOps[] = {
      {"eq", "NotEqual"}, {"neq", "Equal"},    {"lt", "GreaterEqual"},
      {"le", "Greater"},  {"gt", "LessEqual"}, {"ge", "Less"}};

  int found = 0;
  for (const NodeDef& node : output.node()) {
    for (const NameAndValue& expected : kExpectedInputs) {
      if (node.name() == expected.name) {
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), expected.value);
        ++found;
      }
    }
    for (const NameAndValue& expected : kExpectedOps) {
      if (node.name() == expected.name) {
        EXPECT_EQ(node.op(), expected.value);
        ++found;
      }
    }
  }
  // Six rewired Identity nodes plus six negated comparisons.
  EXPECT_EQ(found, 12);

  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), tensors_expected.size());
  EXPECT_EQ(tensors.size(), item.fetch.size());
  for (int i = 0; i < item.fetch.size(); ++i) {
    test::ExpectTensorEqual<bool>(tensors[i], tensors_expected[i]);
  }
}
3941
// A reduction (Max) over a monotonically increasing element-wise op (Sqrt)
// should be reordered so the reduction runs first and the element-wise op
// is applied to the (smaller) reduced tensor.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());
  // Verify the two ops traded places: "sqrt" now consumes "reduce_max",
  // and "reduce_max" reads "x" directly.
  int nodes_verified = 0;
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++nodes_verified;
    } else if (node.name() == "reduce_max") {
      EXPECT_EQ(node.op(), "Max");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++nodes_verified;
    }
  }
  EXPECT_EQ(nodes_verified, 2);
}
3982
// ArgMax of a monotonically increasing op (Sqrt) has the same result as
// ArgMax of the op's input, so the Sqrt should be removed entirely.
TEST_F(ArithmeticOptimizerTest, OptimizeArgMaxOrArgMinOfMonotonicElementWise) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  const auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output arg_max = ops::ArgMax(scope.WithOpName("arg_max"), sqrt, 1);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), arg_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  const auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  const auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorEqual<int64_t>(tensors[0], tensors_expected[0]);
  // One node ("sqrt") should have been pruned away.
  EXPECT_EQ(optimized.node_size(), item.graph.node_size() - 1);
  // Verify the rewiring: arg_max now reads "x" directly.
  int nodes_verified = 0;
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "arg_max");
      ++nodes_verified;
    } else if (node.name() == "arg_max") {
      EXPECT_EQ(node.op(), "ArgMax");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++nodes_verified;
    }
  }
  EXPECT_EQ(nodes_verified, 2);
}
4023
// When the element-wise node ("sqrt") is itself fetched, rewriting it would
// change an observable output, so the optimizer must not touch the graph.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNode) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output reduce_max = ops::Max(scope.WithOpName("reduce_max"), sqrt, {0});
  Output final_out = ops::Identity(scope.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  // Note that "sqrt" is in the fetch set.
  item.fetch = {"sqrt", "final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), 2);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);
}
4047
// When the reduction node ("z") is the fetch node, the rewrite would change
// its op type and thus its observable output, so nothing may be rewritten.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), {2, 3}, {1, 2});
  Output reshape = ops::Reshape(scope.WithOpName("reshape"), x, {-1});
  Output y = ops::Neg(scope.WithOpName("y"), reshape);
  Output z = ops::Max(scope.WithOpName("z"), y, {0});

  GrapplerItem item;
  item.fetch = {"z"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized);

  // Should be a NoOp since we are not allowed to change the output of fetch
  // nodes.
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<int>(tensors[0], tensors_expected[0]);
  // max(neg([2, 3])) == -2.
  test::ExpectTensorEqual<int>(tensors[0], Tensor(-2));
}
4076
// Segment reductions take a second (segment_ids) input whose semantics the
// monotonic-rewrite does not model, so the optimizer must leave all four
// Segment/UnsortedSegment Max/Min variants untouched.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeSegmentMaxOrMinOps) {
  constexpr absl::string_view kSegmentMaxOpName = "SegmentMax";
  constexpr absl::string_view kUnsortedSegmentMaxOpName = "UnsortedSegmentMax";
  constexpr absl::string_view kSegmentMinOpName = "SegmentMin";
  constexpr absl::string_view kUnsortedSegmentMinOpName = "UnsortedSegmentMin";
  constexpr absl::string_view segment_max_or_min_op_names[] = {
      kSegmentMaxOpName, kUnsortedSegmentMaxOpName, kSegmentMinOpName,
      kUnsortedSegmentMinOpName};
  for (const absl::string_view segment_op_name : segment_max_or_min_op_names) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Input x = ops::Const(s.WithOpName("x"), {-1.0f, 2.0f, -3.0f, 4.0f}, {2, 2});
    // Fix: this constant previously reused the node name "x" (relying on the
    // Scope to uniquify it); give the segment ids their own descriptive name.
    Input segment_ids = ops::Const(s.WithOpName("segment_ids"), {0, 2}, {2});
    Output relu = ops::Relu(s.WithOpName("relu"), x);
    Output segment_op;
    // Build the segment reduction variant under test.
    if (segment_op_name == kSegmentMaxOpName) {
      segment_op =
          ops::SegmentMax(s.WithOpName(segment_op_name), relu, segment_ids);
    } else if (segment_op_name == kUnsortedSegmentMaxOpName) {
      segment_op = ops::UnsortedSegmentMax(s.WithOpName(segment_op_name), relu,
                                           segment_ids, 3);
    } else if (segment_op_name == kSegmentMinOpName) {
      segment_op =
          ops::SegmentMin(s.WithOpName(segment_op_name), relu, segment_ids);
    } else {
      segment_op = ops::UnsortedSegmentMin(s.WithOpName(segment_op_name), relu,
                                           segment_ids, 3);
    }
    Output final_out = ops::Identity(s.WithOpName("final_out"), segment_op);

    GrapplerItem item;
    item.fetch = {"relu", "final_out"};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    EXPECT_EQ(tensors_expected.size(), 2);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
    OptimizeTwice(&optimizer, &item, &output);

    // Should be a NoOp: the rewrite is not valid for segment reductions.
    VerifyGraphsMatch(item.graph, output, __LINE__);
  }
}
4121
// For a monotonically *decreasing* element-wise op (Neg), commuting it with
// the reduction additionally flips the reduction: Max(Neg(x)) == Neg(Min(x)).
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output neg = ops::Neg(s.WithOpName("neg"), x);
  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(output.node_size(), item.graph.node_size());
  // Check if the inputs are switched and the reduction op was flipped to Min.
  int required_node_count = 0;
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    if (node.name() == "neg") {
      EXPECT_EQ(node.op(), "Neg");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "reduce_max");
      ++required_node_count;
    } else if (node.name() == "reduce_max") {
      // The node keeps its name but now performs a Min over "x".
      EXPECT_EQ(node.op(), "Min");
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0), "x");
      ++required_node_count;
    }
  }
  // Fix: use (actual, expected) argument order consistently with the rest of
  // this file.
  EXPECT_EQ(required_node_count, 2);
}
4163
// MaxPool after a decreasing op (Neg) must NOT be rewritten: unlike a full
// reduction, flipping a pooling op is not a valid transformation here.
TEST_F(ArithmeticOptimizerTest,
       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output neg = ops::Neg(scope.WithOpName("neg"), x);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), neg,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");

  GrapplerItem item;
  item.fetch = {"max_pool"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &optimized);

  // Should be a NoOp
  VerifyGraphsMatch(item.graph, optimized, __LINE__);

  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4190
// A realistic Conv2D -> BiasAdd -> Relu -> MaxPool chain must be left alone
// by the monotonic rewrite; the whole pipeline is verified as a NoOp.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicBiasAddReluMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output weights = ops::Const(scope.WithOpName("weights"),
                              Input::Initializer(1.0f, {5, 5, 3, 4}));
  Output biases =
      ops::Const(scope.WithOpName("biases"), Input::Initializer(2.0f, {4}));
  Output input = ops::Const(scope.WithOpName("input"),
                            Input::Initializer(1.0f, {1, 28, 28, 3}));
  // Build the chain with one named Output per stage instead of reusing a
  // single variable.
  Output conv = ops::Conv2D(scope.WithOpName("conv"), input, weights,
                            {1, 1, 1, 1}, "SAME");
  Output biasadd = ops::BiasAdd(scope.WithOpName("biasadd"), conv, biases);
  Output relu = ops::Relu(scope.WithOpName("relu"), biasadd);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), relu,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output out = ops::Identity(scope.WithOpName("output"), max_pool);

  GrapplerItem item;
  item.fetch = {"output"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef new_graph;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeTwice(&optimizer, &item, &new_graph);

  // Should be a NoOp
  VerifyGraphsMatch(item.graph, new_graph, __LINE__);

  auto tensors = EvaluateNodes(new_graph, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4225
// MaxPool commutes with a monotonically increasing element-wise op (Sqrt):
// the pool should be moved before the Sqrt so Sqrt runs on fewer elements.
TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  auto x = ops::Const(scope.WithOpName("x"), 1.5f, {3, 3, 3, 1});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output max_pool = ops::MaxPool(scope.WithOpName("max_pool"), sqrt,
                                 {1, 2, 2, 1}, {1, 2, 2, 1}, "VALID");
  Output final_out = ops::Identity(scope.WithOpName("final_out"), max_pool);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);
  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);

  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
  EXPECT_EQ(optimized.node_size(), item.graph.node_size());
  // Verify the two ops traded places: sqrt now consumes max_pool, and
  // max_pool reads "x" directly.
  int nodes_verified = 0;
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "sqrt") {
      EXPECT_EQ(node.op(), "Sqrt");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "max_pool");
      ++nodes_verified;
    } else if (node.name() == "max_pool") {
      EXPECT_EQ(node.op(), "MaxPool");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");
      ++nodes_verified;
    }
  }
  EXPECT_EQ(nodes_verified, 2);
}
4267
// A chain of unary ops (Sqrt -> Log -> Relu) should be fused into a single
// _UnaryOpsComposition node carrying the op list in its "op_names" attr.
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  auto x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output sqrt = ops::Sqrt(scope.WithOpName("sqrt"), x);
  Output log = ops::Log(scope.WithOpName("log"), sqrt);
  Output relu = ops::Relu(scope.WithOpName("relu"), log);
  Output final_out = ops::Identity(scope.WithOpName("final_out"), relu);

  GrapplerItem item;
  item.fetch = {"final_out"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Place all nodes on CPU.
  for (NodeDef& node : *item.graph.mutable_node()) {
    node.set_device("/device:CPU:0");
  }

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlyUnaryOpsComposition(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Only x, the fused op, and final_out should remain.
  EXPECT_EQ(optimized.node_size(), 3);

  // Check that Sqrt/Log/Relu were replaced with a single op.
  int nodes_verified = 0;
  for (const NodeDef& node : optimized.node()) {
    if (node.name() == "final_out") {
      EXPECT_EQ(node.op(), "Identity");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "relu/unary_ops_composition");
      ++nodes_verified;
    } else if (node.name() == "relu/unary_ops_composition") {
      EXPECT_EQ(node.op(), "_UnaryOpsComposition");
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "x");

      // The fused node must list the original ops in execution order.
      const auto& composed_ops = node.attr().at("op_names").list().s();
      ASSERT_EQ(composed_ops.size(), 3);
      EXPECT_EQ(composed_ops[0], "Sqrt");
      EXPECT_EQ(composed_ops[1], "Log");
      EXPECT_EQ(composed_ops[2], "Relu");
      ++nodes_verified;
    }
  }
  EXPECT_EQ(nodes_verified, 2);

  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
4324
// A StridedSlice that selects exactly one of the tensors stacked along the
// new Stack axis should be rewritten to read the stacked input directly
// (with an ExpandDims when the slice keeps the axis). Covers both the
// shrink-axis form (stacked[:, i]) and the keep-axis form (stacked[:, i:j]).
TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
  auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3});
  auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});

  // stacked[:, 0]
  using SS = ops::StridedSlice;
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 1]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 2]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0b0010)));  // 2

  // stacked[:, 0:1, :]
  auto pa_slice_01 = ops::Identity(
      s.WithOpName("pa_slice_01_out"),
      SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, :1, :]
  auto pa_slice_to1 = ops::Identity(
      s.WithOpName("pa_slice_to1_out"),
      SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
         SS::BeginMask(0b0111)  // 7
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 1:2, :]
  auto pb_slice_12 = ops::Identity(
      s.WithOpName("pb_slice_12_out"),
      SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0101)  // 5
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  // stacked[:, 2:, :].
  auto pc_slice_2to = ops::Identity(
      s.WithOpName("pc_slice_2to_out"),
      SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides,
         SS::BeginMask(0b0101)  // 5
             .EllipsisMask(0)
             .EndMask(0b0111)  // 7
             .NewAxisMask(0)
             .ShrinkAxisMask(0)));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c",
                "pa_slice_01_out",
                "pa_slice_to1_out",
                "pb_slice_12_out",
                "pc_slice_2to_out"};
  // Symbolic indices into item.fetch / the evaluated tensor vectors.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
    fASlice01Out,
    fASliceTo1Out,
    fBSlice12Out,
    fCSlice2ToOut,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fA]);
  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fB]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fC]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fASlice01Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fASliceTo1Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fBSlice12Out],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors_expected[fCSlice2ToOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "a");
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "b");
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "c");
      // Fix: use absl::EndsWith for consistency with the absl::StartsWith
      // usage elsewhere in this file (str_util is the legacy spelling).
    } else if (absl::EndsWith(node.name(), "_out")) {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(
          absl::StrCat(node.input(0), "_out"),
          absl::StrCat("ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
                       node.name()));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // stacked[:, 0, :] == a.
  test::ExpectTensorEqual<float>(tensors[fASliceOut], tensors_expected[fA]);

  // stacked[:, 1, :] == b.
  test::ExpectTensorEqual<float>(tensors[fBSliceOut], tensors_expected[fB]);
  // stacked[:, 2, :] == c.
  test::ExpectTensorEqual<float>(tensors[fCSliceOut], tensors_expected[fC]);

  // stacked[:, 0:1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors[fASlice01Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, :1, :] == expand_dims(a, 1).
  test::ExpectTensorEqual<float>(tensors[fASliceTo1Out],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == expand_dims(b, 1).
  test::ExpectTensorEqual<float>(tensors[fBSlice12Out],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:, :] == expand_dims(c, 1).
  test::ExpectTensorEqual<float>(tensors[fCSlice2ToOut],
                                 tensors_expected[fExpandedC]);
}
4533
// Same rewrite as RemoveStackStridedSliceSameAxis, but triggered by the
// simpler Slice op: a Slice that selects exactly one stacked tensor along
// the Stack axis is replaced by an ExpandDims of the original input.
TEST_F(ArithmeticOptimizerTest, RemoveStackSimpleSliceSameAxis) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto a_in =
      ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
  auto b_in =
      ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
  auto c_in =
      ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
  // Placeholders with unknown shapes so the rewrite cannot rely on static
  // shape information from the inputs.
  auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
                                       PartialTensorShape({-1, -1}));
  auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
                                       PartialTensorShape({-1, -1}));
  auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
                                       PartialTensorShape({-1, -1}));
  // stacked = tf.stack((a, b, c), axis=1).
  // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
  auto stacked =
      ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
                 ops::Stack::Axis(1));
  // Reference values: expand_dims(x, 1) is what each slice should equal.
  auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
  auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
  auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
  auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
  auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
  auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
  // Sizes {-1, 1, -1}: full extent on axes 0 and 2, one element on axis 1.
  auto sizes_to_end = ops::Const(s.WithOpName("size"), {-1, 1, -1}, {3});

  // stacked[:, 0:1, :]
  auto pa_slice = ops::Identity(
      s.WithOpName("pa_slice_out"),
      ops::Slice(s.WithOpName("pa_slice"), stacked, begin_a, sizes_to_end));

  // stacked[:, 1:2, :]
  auto pb_slice = ops::Identity(
      s.WithOpName("pb_slice_out"),
      ops::Slice(s.WithOpName("pb_slice"), stacked, begin_b, sizes_to_end));

  // stacked[:, 2:3, :]
  auto pc_slice = ops::Identity(
      s.WithOpName("pc_slice_out"),
      ops::Slice(s.WithOpName("pc_slice"), stacked, begin_c, sizes_to_end));

  GrapplerItem item;
  item.fetch = {"a",
                "b",
                "c",
                "pa_slice_out",
                "pb_slice_out",
                "pc_slice_out",
                "expanded_a",
                "expanded_b",
                "expanded_c"};
  // Symbolic indices into item.fetch / the evaluated tensor vectors.
  enum FetchItem {
    fA,
    fB,
    fC,
    fASliceOut,
    fBSliceOut,
    fCSliceOut,
    fExpandedA,
    fExpandedB,
    fExpandedC,
  };
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Sanity-check the unoptimized graph before optimizing.
  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors_expected[fASliceOut],
                                 tensors_expected[fExpandedA]);
  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors_expected[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors_expected[fCSliceOut],
                                 tensors_expected[fExpandedC]);

  GraphDef output;
  ArithmeticOptimizer optimizer;
  EnableOnlyRemoveStackSliceSameAxis(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  // Name prefix the rewrite uses for the ExpandDims nodes it inserts.
  const string kExpandDimsNamePrefix(
      "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_p");

  for (const auto& node : output.node()) {
    if (node.name() == "pa_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "a_slice"));
    } else if (node.name() == "pb_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "b_slice"));
    } else if (node.name() == "pc_slice_out") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), absl::StrCat(kExpandDimsNamePrefix, "c_slice"));
    } else if (absl::StartsWith(node.name(), kExpandDimsNamePrefix)) {
      EXPECT_EQ(node.op(), "ExpandDims");
      // The input is "a", "b", or "c", as appropriate.
      EXPECT_EQ(node.input(0),
                node.name().substr(kExpandDimsNamePrefix.size(), 1));
    }
  }

  auto tensors = EvaluateNodes(output, item.fetch);

  // The rewritten graph must produce the same values as before.
  // stacked[:, 0:1, :] == a.
  test::ExpectTensorEqual<float>(tensors[fASliceOut],
                                 tensors_expected[fExpandedA]);

  // stacked[:, 1:2, :] == b.
  test::ExpectTensorEqual<float>(tensors[fBSliceOut],
                                 tensors_expected[fExpandedB]);
  // stacked[:, 2:3, :] == c.
  test::ExpectTensorEqual<float>(tensors[fCSliceOut],
                                 tensors_expected[fExpandedC]);
}
4649
// AddN of the same bfloat16 input twice should be simplified to a multiply
// by a constant, preserving the numeric result.
TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  Output x = ops::Const(scope.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
  Output cast = ops::Cast(scope.WithOpName("cast"), x, DT_BFLOAT16);
  Output add = ops::AddN(scope.WithOpName("add"), {cast, cast});
  Output id = ops::Identity(scope.WithOpName("id"), add);

  GrapplerItem item;
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
  item.fetch = {"id"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  GraphDef optimized;
  ArithmeticOptimizer optimizer;
  EnableOnlySimplifyAggregation(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &optimized);

  // Extra node created for multiplier.
  EXPECT_EQ(optimized.node_size(), 5);

  auto tensors = EvaluateNodes(optimized, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]);
}
4675
TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
  // The Unique+Gather+SparseSegmentSum embedding-lookup pattern should be
  // collapsed so that SparseSegmentSum reads the embeddings and raw indices
  // directly; exercise both supported index types of Unique.
  for (DataType idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(scope.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output raw_indices =
        ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto uniqued = ops::Unique(scope.WithOpName("unique"), raw_indices,
                               /*attrs=*/{idx_type});
    Output rows =
        ops::Gather(scope.WithOpName("gathered_rows"), embeddings, uniqued.y);
    Output summed = ops::SparseSegmentSum(scope.WithOpName("result"), rows,
                                          uniqued.idx, segment_ids);
    Output fetch = ops::Identity(scope.WithOpName("id"), summed);

    GrapplerItem item;
    item.fetch = {"id"};
    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
    auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(expected_tensors.size(), 1);

    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    GraphDef optimized_graph;
    OptimizeAndPrune(&optimizer, &item, &optimized_graph);

    for (const auto& node : optimized_graph.node()) {
      // After the rewrite, "result" consumes the embeddings and the original
      // (non-uniqued) indices, and no Unique/Gather nodes survive.
      if (node.name() == "result") {
        EXPECT_EQ(node.input(0), "embeddings");
        EXPECT_EQ(node.input(1), "indices");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }

    auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
    ASSERT_EQ(actual_tensors.size(), 1);
    test::ExpectTensorEqual<float>(actual_tensors[0], expected_tensors[0]);
  }
}
4719
TEST_F(ArithmeticOptimizerTest, SimplifyResourceEmbeddingLookup) {
  // Same embedding-lookup simplification as above, but the embeddings live in
  // a resource variable and are fetched via ResourceGather. The optimizer is
  // expected to replace Unique+ResourceGather with a ReadVariableOp feeding
  // SparseSegmentSum directly, carrying over the "_class" colocation attr.
  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output embeddings = ops::Const(s.WithOpName("embeddings"),
                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
    Output segment_ids =
        ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
    Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
    auto unique = ops::Unique(s.WithOpName("unique"), indices,
                              /*attrs=*/{unique_idx_type});
    Output ids = unique.y;
    Output idx = unique.idx;

    // Resource variable holding the embedding table; initialized by
    // "assign_var_handle", which is registered as an init op below.
    auto var =
        ops::VarHandleOp(s.WithOpName("var"), DT_FLOAT, TensorShape({2, 2}));
    ops::AssignVariableOp assign_op(s.WithOpName("assign_var_handle"), var,
                                    embeddings);

    // Control dependency makes sure the gather runs after initialization.
    Output gathered_rows = ops::ResourceGather(
        s.WithOpName("gathered_rows")
            .WithControlDependencies(std::vector<Operation>{assign_op}),
        var, ids, DT_FLOAT);
    // Attach a colocation class so we can check the optimizer propagates it
    // onto the ReadVariableOp it creates.
    gathered_rows.node()->AddAttr("_class", {"test_class"});
    Output result =
        ops::SparseSegmentSum(s.WithOpName("result").WithControlDependencies(
                                  std::vector<Operation>{assign_op}),
                              gathered_rows, idx, segment_ids);
    Output id = ops::Identity(s.WithOpName("id"), result);

    GrapplerItem item;
    item.init_ops.push_back("assign_var_handle");
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"id"};
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
    ASSERT_EQ(tensors_expected.size(), 1);

    GraphDef output;
    ArithmeticOptimizer optimizer;
    EnableOnlySimplifyEmbeddingLookup(&optimizer);
    OptimizeAndPrune(&optimizer, &item, &output);
    bool read_var_node_found = false;
    for (const auto& node : output.node()) {
      if (node.name() == "result") {
        // "result" now consumes the variable read and the raw indices.
        EXPECT_EQ(
            node.input(0),
            "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result");
        EXPECT_EQ(node.input(1), "indices");
      }
      if (node.op() == "ReadVariableOp") {
        read_var_node_found = true;
        // The "_class" attr from the original ResourceGather must carry over.
        EXPECT_EQ(node.attr().at("_class").list().s(0), "test_class");
      }
      EXPECT_NE(node.op(), "Unique");
      EXPECT_NE(node.op(), "Gather");
    }
    EXPECT_TRUE(read_var_node_found);
    // Add a control dependency to the ReadVar to do the AssignVar first. This
    // shouldn't be an issue in actual use as variables are assumed initialized
    // during setup.
    for (int i = 0; i < output.node_size(); ++i) {
      if (output.node(i).name() ==
          "ArithmeticOptimizer/SimplifyEmbeddingLookupStage_ReadVar_result") {
        output.mutable_node(i)->add_input("^assign_var_handle");
      }
    }

    auto tensors = EvaluateNodes(output, item.fetch);
    ASSERT_EQ(tensors.size(), 1);
    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
  }
}
4791
TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) {
  // Casts feeding the indices/segment_ids inputs of SparseSegmentSum are
  // redundant (the op accepts either integer type) and should be removed.
  // Check all four combinations of index and segment-id types.
  for (DataType idx_dtype : {DT_INT32, DT_INT64}) {
    for (DataType seg_dtype : {DT_INT32, DT_INT64}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      Output embeddings = ops::Const(scope.WithOpName("embeddings"),
                                     {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
      Output cast_indices = ops::Cast(
          scope.WithOpName("cast_indices"),
          ops::Const(scope.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}),
          idx_dtype);
      Output cast_segments = ops::Cast(
          scope.WithOpName("cast_segment_ids"),
          ops::Const(scope.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}),
          seg_dtype);
      Output summed =
          ops::SparseSegmentSum(scope.WithOpName("result"), embeddings,
                                cast_indices, cast_segments);
      Output fetch = ops::Identity(scope.WithOpName("id"), summed);

      GrapplerItem item;
      item.fetch = {"id"};
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      auto expected_tensors = EvaluateNodes(item.graph, item.fetch);
      ASSERT_EQ(expected_tensors.size(), 1);

      ArithmeticOptimizer optimizer;
      EnableOnlyRemoveCastIntoSegmentReduction(&optimizer);
      GraphDef optimized_graph;
      OptimizeAndPrune(&optimizer, &item, &optimized_graph);

      for (const auto& node : optimized_graph.node()) {
        // "result" must consume the uncasted constants, and no Cast nodes
        // may remain anywhere in the graph.
        if (node.name() == "result") {
          EXPECT_EQ(node.input(1), "indices");
          EXPECT_EQ(node.input(2), "segment_ids");
        }
        EXPECT_NE(node.op(), "Cast");
      }

      auto actual_tensors = EvaluateNodes(optimized_graph, item.fetch);
      ASSERT_EQ(actual_tensors.size(), 1);
      test::ExpectTensorEqual<float>(actual_tensors[0], expected_tensors[0]);
    }
  }
}
4835
4836 } // namespace grappler
4837 } // namespace tensorflow
4838