1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/grappler/clusters/cluster.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/mutable_graph_view.h"
24 #include "tensorflow/core/grappler/op_types.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/strcat.h"
30
31 namespace tensorflow {
32 namespace grappler {
33 namespace {
34
// Op name of the downstream node that triggers a fusion attempt.
constexpr char kRepeatDataset[] = "RepeatDataset";

// Op names of the upstream shuffle variants that can be fused with a repeat.
constexpr char kShuffleDataset[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";

// Op names assigned to the fused replacement node.
constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";
constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";

// Name of the boolean attribute carried over to (or defaulted on) the fused
// node.
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
43
FuseShuffleV1AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)44 Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
45 const NodeDef& repeat_node,
46 MutableGraphView* graph, GraphDef* output,
47 NodeDef* fused_node) {
48 fused_node->set_op(kShuffleAndRepeatDataset);
49 graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
50 fused_node);
51
52 // Set the `input` input argument.
53 fused_node->add_input(shuffle_node.input(0));
54
55 // Set the `buffer_size` input argument.
56 fused_node->add_input(shuffle_node.input(1));
57
58 // Set the `seed` input argument.
59 fused_node->add_input(shuffle_node.input(2));
60
61 // Set the `seed2` input argument.
62 fused_node->add_input(shuffle_node.input(3));
63
64 // Set the `count` input argument.
65 fused_node->add_input(repeat_node.input(1));
66
67 // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
68 // attributes.
69 graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
70 graph_utils::CopyAttribute(kReshuffleEachIteration, shuffle_node, fused_node);
71
72 // Optionally set the `metadata` attribute.
73 graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
74
75 return OkStatus();
76 }
77
FuseShuffleV2AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)78 Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
79 const NodeDef& repeat_node,
80 MutableGraphView* graph, GraphDef* output,
81 NodeDef* fused_node) {
82 fused_node->set_op(kShuffleAndRepeatDatasetV2);
83 graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
84 fused_node);
85
86 NodeDef zero_node = *graph_utils::AddScalarConstNode<int64_t>(0, graph);
87
88 // Set the `input` input argument.
89 fused_node->add_input(shuffle_node.input(0));
90
91 // Set the `buffer_size` input argument.
92 fused_node->add_input(shuffle_node.input(1));
93
94 // Default the `seed` input argument to 0.
95 fused_node->add_input(zero_node.name());
96
97 // Default the `seed2` input argument to 0.
98 fused_node->add_input(zero_node.name());
99
100 // Set the `count` input argument.
101 fused_node->add_input(repeat_node.input(1));
102
103 // Set the `seed_generator` input argument.
104 fused_node->add_input(shuffle_node.input(2));
105
106 // Set `output_types` and `output_shapes` attributes.
107 graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
108
109 // Default the `reshuffle_each_iteration` attribute to true.
110 (*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
111
112 // Optionally set the `metadata` attribute.
113 graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
114
115 return OkStatus();
116 }
117
FuseShuffleV3AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)118 Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
119 const NodeDef& repeat_node,
120 MutableGraphView* graph, GraphDef* output,
121 NodeDef* fused_node) {
122 fused_node->set_op(kShuffleAndRepeatDatasetV2);
123 graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
124 fused_node);
125
126 // Set the `input` input argument.
127 fused_node->add_input(shuffle_node.input(0));
128
129 // Set the `buffer_size` input argument.
130 fused_node->add_input(shuffle_node.input(1));
131
132 // Set the `seed` input argument.
133 fused_node->add_input(shuffle_node.input(2));
134
135 // Set the `seed2` input argument.
136 fused_node->add_input(shuffle_node.input(3));
137
138 // Set the `count` input argument.
139 fused_node->add_input(repeat_node.input(1));
140
141 // Set the `seed_generator` input argument.
142 fused_node->add_input(shuffle_node.input(4));
143
144 // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
145 // attributes.
146 graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
147 graph_utils::CopyAttribute(kReshuffleEachIteration, shuffle_node, fused_node);
148
149 // Optionally set the `metadata` attribute.
150 graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
151
152 return OkStatus();
153 }
154
155 } // namespace
156
OptimizeAndCollectStats(Cluster * cluster,const GrapplerItem & item,GraphDef * output,OptimizationStats * stats)157 Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
158 Cluster* cluster, const GrapplerItem& item, GraphDef* output,
159 OptimizationStats* stats) {
160 *output = item.graph;
161 MutableGraphView graph(output);
162 absl::flat_hash_set<string> nodes_to_delete;
163
164 for (const NodeDef& repeat_node : item.graph.node()) {
165 if (repeat_node.op() != kRepeatDataset) {
166 continue;
167 }
168
169 const NodeDef& shuffle_node =
170 *graph_utils::GetInputNode(repeat_node, graph);
171
172 NodeDef fused_node;
173 if (shuffle_node.op() == kShuffleDataset) {
174 TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
175 &graph, output, &fused_node));
176 } else if (shuffle_node.op() == kShuffleDatasetV2) {
177 TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
178 &graph, output, &fused_node));
179
180 } else if (shuffle_node.op() == kShuffleDatasetV3) {
181 TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
182 &graph, output, &fused_node));
183 } else {
184 continue;
185 }
186
187 NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
188 TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
189 shuffle_and_repeat_node.name()));
190 // Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
191 // control dependencies.
192 TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
193 shuffle_and_repeat_node.name()));
194
195 // Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
196 // them needs to be preserved).
197 const auto nodes_to_preserve = item.NodesToPreserve();
198 if (nodes_to_preserve.find(shuffle_node.name()) ==
199 nodes_to_preserve.end() &&
200 nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
201 nodes_to_delete.insert(shuffle_node.name());
202 nodes_to_delete.insert(repeat_node.name());
203 }
204 stats->num_changes++;
205 }
206
207 TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
208 return OkStatus();
209 }
210
// Registers `ShuffleAndRepeatFusion` as a graph optimizer under the name
// "shuffle_and_repeat_fusion".
REGISTER_GRAPH_OPTIMIZER_AS(ShuffleAndRepeatFusion,
                            "shuffle_and_repeat_fusion");
213
214 } // namespace grappler
215 } // namespace tensorflow
216