xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/framework/node_def.pb.h"
20 #include "tensorflow/core/framework/node_def_util.h"
21 #include "tensorflow/core/grappler/clusters/cluster.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/mutable_graph_view.h"
24 #include "tensorflow/core/grappler/op_types.h"
25 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
26 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/strcat.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 namespace {
34 
35 constexpr char kShuffleDataset[] = "ShuffleDataset";  // Shuffle op handled by FuseShuffleV1AndRepeat.
36 constexpr char kShuffleDatasetV2[] = "ShuffleDatasetV2";  // Shuffle op handled by FuseShuffleV2AndRepeat.
37 constexpr char kShuffleDatasetV3[] = "ShuffleDatasetV3";  // Shuffle op handled by FuseShuffleV3AndRepeat.
38 constexpr char kRepeatDataset[] = "RepeatDataset";  // Op of the nodes the optimizer scans for.
39 constexpr char kShuffleAndRepeatDataset[] = "ShuffleAndRepeatDataset";  // Fused op for the V1 shuffle; also the node-name prefix used for V1 and V3.
40 constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2";  // Fused op for the V2 and V3 shuffles.
41 
42 constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";  // Attribute name copied (V1, V3) or set to true (V2) on the fused node.
43 
FuseShuffleV1AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)44 Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node,
45                               const NodeDef& repeat_node,
46                               MutableGraphView* graph, GraphDef* output,
47                               NodeDef* fused_node) {
48   fused_node->set_op(kShuffleAndRepeatDataset);
49   graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
50                                       fused_node);
51 
52   // Set the `input` input argument.
53   fused_node->add_input(shuffle_node.input(0));
54 
55   // Set the `buffer_size` input argument.
56   fused_node->add_input(shuffle_node.input(1));
57 
58   // Set the `seed` input argument.
59   fused_node->add_input(shuffle_node.input(2));
60 
61   // Set the `seed2` input argument.
62   fused_node->add_input(shuffle_node.input(3));
63 
64   // Set the `count` input argument.
65   fused_node->add_input(repeat_node.input(1));
66 
67   // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
68   // attributes.
69   graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
70   graph_utils::CopyAttribute(kReshuffleEachIteration, shuffle_node, fused_node);
71 
72   // Optionally set the `metadata` attribute.
73   graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
74 
75   return OkStatus();
76 }
77 
FuseShuffleV2AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)78 Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node,
79                               const NodeDef& repeat_node,
80                               MutableGraphView* graph, GraphDef* output,
81                               NodeDef* fused_node) {
82   fused_node->set_op(kShuffleAndRepeatDatasetV2);
83   graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output,
84                                       fused_node);
85 
86   NodeDef zero_node = *graph_utils::AddScalarConstNode<int64_t>(0, graph);
87 
88   // Set the `input` input argument.
89   fused_node->add_input(shuffle_node.input(0));
90 
91   // Set the `buffer_size` input argument.
92   fused_node->add_input(shuffle_node.input(1));
93 
94   // Default the `seed` input argument to 0.
95   fused_node->add_input(zero_node.name());
96 
97   // Default the `seed2` input argument to 0.
98   fused_node->add_input(zero_node.name());
99 
100   // Set the `count` input argument.
101   fused_node->add_input(repeat_node.input(1));
102 
103   // Set the `seed_generator` input argument.
104   fused_node->add_input(shuffle_node.input(2));
105 
106   // Set `output_types` and `output_shapes` attributes.
107   graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
108 
109   // Default the `reshuffle_each_iteration` attribute to true.
110   (*fused_node->mutable_attr())[kReshuffleEachIteration].set_b(true);
111 
112   // Optionally set the `metadata` attribute.
113   graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
114 
115   return OkStatus();
116 }
117 
FuseShuffleV3AndRepeat(const NodeDef & shuffle_node,const NodeDef & repeat_node,MutableGraphView * graph,GraphDef * output,NodeDef * fused_node)118 Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node,
119                               const NodeDef& repeat_node,
120                               MutableGraphView* graph, GraphDef* output,
121                               NodeDef* fused_node) {
122   fused_node->set_op(kShuffleAndRepeatDatasetV2);
123   graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output,
124                                       fused_node);
125 
126   // Set the `input` input argument.
127   fused_node->add_input(shuffle_node.input(0));
128 
129   // Set the `buffer_size` input argument.
130   fused_node->add_input(shuffle_node.input(1));
131 
132   // Set the `seed` input argument.
133   fused_node->add_input(shuffle_node.input(2));
134 
135   // Set the `seed2` input argument.
136   fused_node->add_input(shuffle_node.input(3));
137 
138   // Set the `count` input argument.
139   fused_node->add_input(repeat_node.input(1));
140 
141   // Set the `seed_generator` input argument.
142   fused_node->add_input(shuffle_node.input(4));
143 
144   // Set `output_types`, `output_shapes`, and `reshuffle_each_iteration`
145   // attributes.
146   graph_utils::CopyShapesAndTypesAttrs(shuffle_node, fused_node);
147   graph_utils::CopyAttribute(kReshuffleEachIteration, shuffle_node, fused_node);
148 
149   // Optionally set the `metadata` attribute.
150   graph_utils::MaybeSetFusedMetadata(shuffle_node, repeat_node, fused_node);
151 
152   return OkStatus();
153 }
154 
155 }  // namespace
156 
OptimizeAndCollectStats(Cluster * cluster,const GrapplerItem & item,GraphDef * output,OptimizationStats * stats)157 Status ShuffleAndRepeatFusion::OptimizeAndCollectStats(
158     Cluster* cluster, const GrapplerItem& item, GraphDef* output,
159     OptimizationStats* stats) {
160   *output = item.graph;
161   MutableGraphView graph(output);
162   absl::flat_hash_set<string> nodes_to_delete;
163 
164   for (const NodeDef& repeat_node : item.graph.node()) {
165     if (repeat_node.op() != kRepeatDataset) {
166       continue;
167     }
168 
169     const NodeDef& shuffle_node =
170         *graph_utils::GetInputNode(repeat_node, graph);
171 
172     NodeDef fused_node;
173     if (shuffle_node.op() == kShuffleDataset) {
174       TF_RETURN_IF_ERROR(FuseShuffleV1AndRepeat(shuffle_node, repeat_node,
175                                                 &graph, output, &fused_node));
176     } else if (shuffle_node.op() == kShuffleDatasetV2) {
177       TF_RETURN_IF_ERROR(FuseShuffleV2AndRepeat(shuffle_node, repeat_node,
178                                                 &graph, output, &fused_node));
179 
180     } else if (shuffle_node.op() == kShuffleDatasetV3) {
181       TF_RETURN_IF_ERROR(FuseShuffleV3AndRepeat(shuffle_node, repeat_node,
182                                                 &graph, output, &fused_node));
183     } else {
184       continue;
185     }
186 
187     NodeDef& shuffle_and_repeat_node = *graph.AddNode(std::move(fused_node));
188     TF_RETURN_IF_ERROR(graph.UpdateFanouts(repeat_node.name(),
189                                            shuffle_and_repeat_node.name()));
190     // Update shuffle node fanouts to shuffle_and_repeat fanouts to take care of
191     // control dependencies.
192     TF_RETURN_IF_ERROR(graph.UpdateFanouts(shuffle_node.name(),
193                                            shuffle_and_repeat_node.name()));
194 
195     // Mark the `Shuffle` and `Repeat` nodes for removal (as long as neither of
196     // them needs to be preserved).
197     const auto nodes_to_preserve = item.NodesToPreserve();
198     if (nodes_to_preserve.find(shuffle_node.name()) ==
199             nodes_to_preserve.end() &&
200         nodes_to_preserve.find(repeat_node.name()) == nodes_to_preserve.end()) {
201       nodes_to_delete.insert(shuffle_node.name());
202       nodes_to_delete.insert(repeat_node.name());
203     }
204     stats->num_changes++;
205   }
206 
207   TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
208   return OkStatus();
209 }
210 
211 REGISTER_GRAPH_OPTIMIZER_AS(ShuffleAndRepeatFusion,
212                             "shuffle_and_repeat_fusion");  // Presumably registers the optimizer under this name -- confirm against custom_graph_optimizer_registry.h.
213 
214 }  // namespace grappler
215 }  // namespace tensorflow
216