1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
18 
19 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
20 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
21 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
22 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
23 #include "tensorflow/core/grappler/utils/grappler_test.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 
29 class ArithmeticOptimizerTest : public GrapplerTest {
30  protected:
31   // Optimize a graph using optimizer and prune all the nodes that no
32   // longer have any output consumers.
OptimizeAndPrune(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)33   void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
34                         GraphDef* output) {
35     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
36     item->graph.Swap(output);
37     output->Clear();
38     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
39   }
40 
41   // Run optimizer twice to make sure the rewrite is idempotent.
DedupAndOptimizeTwiceAndPrune(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)42   void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
43                                      GrapplerItem* item, GraphDef* output) {
44     TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
45     item->graph.Swap(output);
46     output->Clear();
47     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
48     item->graph.Swap(output);
49     output->Clear();
50     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
51     item->graph.Swap(output);
52     output->Clear();
53     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
54   }
55 
56   // Run optimizer twice to make sure the rewrite is idempotent.
OptimizeTwice(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)57   void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item,
58                      GraphDef* output) {
59     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
60     item->graph.Swap(output);
61     output->Clear();
62     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
63   }
64 
65   // Run optimizer twice to make sure the rewrite is idempotent.
66   // Optionally run a constant folding pass before pruning.
67   void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
68                              GraphDef* output, bool const_folding = false) {
69     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
70 
71     item->graph.Swap(output);
72     output->Clear();
73     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
74 
75     if (const_folding) {
76       item->graph.Swap(output);
77       output->Clear();
78       TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
79                        .Optimize(nullptr, *item, output));
80     }
81 
82     item->graph.Swap(output);
83     output->Clear();
84     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
85   }
86 
DisableAddToAddNCombining(ArithmeticOptimizer * optimizer)87   void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
88     optimizer->options_.combine_add_to_addn = false;
89   }
90 
EnableOnlyAddToAddNCombining(ArithmeticOptimizer * optimizer)91   void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
92     DisableAllStages(optimizer);
93     optimizer->options_.combine_add_to_addn = true;
94   }
95 
EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer * optimizer)96   void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
97     DisableAllStages(optimizer);
98     optimizer->options_.fold_conjugate_into_transpose = true;
99   }
100 
EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer * optimizer)101   void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
102     DisableAllStages(optimizer);
103     optimizer->options_.fold_multiply_into_conv = true;
104   }
105 
EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer * optimizer)106   void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
107     DisableAllStages(optimizer);
108     optimizer->options_.fold_transpose_into_matmul = true;
109   }
110 
EnableOnlyHoistCommonFactor(ArithmeticOptimizer * optimizer)111   void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
112     DisableAllStages(optimizer);
113     optimizer->options_.hoist_common_factor_out_of_aggregation = true;
114   }
115 
EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer * optimizer)116   void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
117     DisableAllStages(optimizer);
118     optimizer->options_.minimize_broadcasts = true;
119   }
120 
EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer * optimizer)121   void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
122     DisableAllStages(optimizer);
123     optimizer->options_.remove_identity_transpose = true;
124   }
125 
EnableOnlyRemoveInvolution(ArithmeticOptimizer * optimizer)126   void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
127     DisableAllStages(optimizer);
128     optimizer->options_.remove_involution = true;
129   }
130 
EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer * optimizer)131   void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
132     DisableAllStages(optimizer);
133     optimizer->options_.remove_redundant_bitcast = true;
134   }
135 
EnableOnlyRemoveRedundantCast(ArithmeticOptimizer * optimizer)136   void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
137     DisableAllStages(optimizer);
138     optimizer->options_.remove_redundant_cast = true;
139   }
140 
EnableOnlyReduceUpsamplingDims(ArithmeticOptimizer * optimizer)141   void EnableOnlyReduceUpsamplingDims(ArithmeticOptimizer* optimizer) {
142     DisableAllStages(optimizer);
143     optimizer->options_.reduce_upsampling_dims = true;
144   }
145 
EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer * optimizer)146   void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
147     DisableAllStages(optimizer);
148     optimizer->options_.remove_redundant_reshape = true;
149   }
150 
EnableOnlyRemoveNegation(ArithmeticOptimizer * optimizer)151   void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
152     DisableAllStages(optimizer);
153     optimizer->options_.remove_negation = true;
154   }
155 
EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer * optimizer)156   void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
157     DisableAllStages(optimizer);
158     optimizer->options_.reorder_cast_like_and_value_preserving = true;
159   }
160 
EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer * optimizer)161   void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) {
162     DisableAllStages(optimizer);
163     optimizer->options_.replace_mul_with_tile = true;
164   }
165 
EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer * optimizer)166   void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
167     DisableAllStages(optimizer);
168     optimizer->options_.replace_mul_with_square = true;
169   }
170 
EnableOnlyReplacePackWithTileReshape(ArithmeticOptimizer * optimizer)171   void EnableOnlyReplacePackWithTileReshape(ArithmeticOptimizer* optimizer) {
172     DisableAllStages(optimizer);
173     optimizer->options_.replace_pack_with_tile_reshape = true;
174   }
175 
EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer * optimizer)176   void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
177     DisableAllStages(optimizer);
178     optimizer->options_.hoist_cwise_unary_chains = true;
179   }
180 
EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer * optimizer)181   void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
182     DisableAllStages(optimizer);
183     optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
184   }
185 
EnableOnlyLogSoftmax(ArithmeticOptimizer * optimizer)186   void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
187     DisableAllStages(optimizer);
188     optimizer->options_.convert_log_softmax = true;
189   }
190 
EnableOnlyConvertPow(ArithmeticOptimizer * optimizer)191   void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
192     DisableAllStages(optimizer);
193     optimizer->options_.convert_pow = true;
194   }
195 
EnableOnlyFuseSquaredDiff(ArithmeticOptimizer * optimizer)196   void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) {
197     DisableAllStages(optimizer);
198     optimizer->options_.fuse_squared_diff = true;
199   }
200 
EnableOnlyRemoveIdempotent(ArithmeticOptimizer * optimizer)201   void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
202     DisableAllStages(optimizer);
203     optimizer->options_.remove_idempotent = true;
204   }
205 
EnableOnlyRemoveLogicalNot(ArithmeticOptimizer * optimizer)206   void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
207     DisableAllStages(optimizer);
208     optimizer->options_.remove_logical_not = true;
209   }
210 
EnableOnlySimplifyAggregation(ArithmeticOptimizer * optimizer)211   void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
212     DisableAllStages(optimizer);
213     optimizer->options_.simplify_aggregation = true;
214   }
215 
EnableOnlyLog1p(ArithmeticOptimizer * optimizer)216   void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
217     DisableAllStages(optimizer);
218     optimizer->options_.convert_log1p = true;
219   }
220 
EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer * optimizer)221   void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
222     DisableAllStages(optimizer);
223     optimizer->options_.optimize_max_or_min_of_monotonic = true;
224   }
225 
EnableOnlyExpm1(ArithmeticOptimizer * optimizer)226   void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
227     DisableAllStages(optimizer);
228     optimizer->options_.convert_expm1 = true;
229   }
230 
EnableOnlyUnaryOpsComposition(ArithmeticOptimizer * optimizer)231   void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
232     DisableAllStages(optimizer);
233     optimizer->options_.unary_ops_composition = true;
234   }
235 
EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer * optimizer)236   void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) {
237     DisableAllStages(optimizer);
238     optimizer->options_.remove_stack_slice_same_axis = true;
239   }
240 
EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer * optimizer)241   void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) {
242     DisableAllStages(optimizer);
243     optimizer->options_.simplify_embedding_lookup = true;
244   }
245 
EnableOnlyRemoveCastIntoSegmentReduction(ArithmeticOptimizer * optimizer)246   void EnableOnlyRemoveCastIntoSegmentReduction(
247       ArithmeticOptimizer* optimizer) {
248     DisableAllStages(optimizer);
249     optimizer->options_.remove_cast_into_segment_reduction = true;
250   }
251 
252  private:
DisableAllStages(ArithmeticOptimizer * optimizer)253   void DisableAllStages(ArithmeticOptimizer* optimizer) {
254     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
255     options.dedup_computations = false;
256     options.combine_add_to_addn = false;
257     options.convert_sqrt_div_to_rsqrt_mul = false;
258     options.convert_pow = false;
259     options.convert_log1p = false;
260     options.optimize_max_or_min_of_monotonic = false;
261     options.fold_conjugate_into_transpose = false;
262     options.fold_multiply_into_conv = false;
263     options.fold_transpose_into_matmul = false;
264     options.hoist_common_factor_out_of_aggregation = false;
265     options.hoist_cwise_unary_chains = false;
266     options.minimize_broadcasts = false;
267     options.remove_identity_transpose = false;
268     options.remove_involution = false;
269     options.remove_idempotent = false;
270     options.remove_redundant_bitcast = false;
271     options.remove_redundant_cast = false;
272     options.remove_redundant_reshape = false;
273     options.remove_negation = false;
274     options.remove_logical_not = false;
275     options.reorder_cast_like_and_value_preserving = false;
276     options.replace_mul_with_tile = false;
277     options.replace_mul_with_square = false;
278     options.simplify_aggregation = false;
279     options.unary_ops_composition = false;
280     options.simplify_embedding_lookup = false;
281     options.remove_cast_into_segment_reduction = false;
282     optimizer->options_ = options;
283   }
284 };
285 
286 }  // end namespace grappler
287 }  // end namespace tensorflow
288 
289 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
290