/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Graph rewrite passes for TPU embedding ops: assigning device ordinals to
// enqueue ops and resolving their mode override (training vs. eval) input.

#include "tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.h"

#include <map>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {

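// Enqueue ops that feed input batches to the TPU embedding engine. Their last
// non-control input is the mode override string resolved by
// UpdateTPUEmbeddingModePass below.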
constexpr absl::string_view kTPUEmbeddingOps[] = {
    "EnqueueTPUEmbeddingBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch"};

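// Ops that receive embedding activations for a TPUEmbedding layer call.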
constexpr absl::string_view kTPURecvOps[] = {"RecvTPUEmbeddingActivations",
                                             "XlaRecvTPUEmbeddingActivations"};

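// Ops that send gradients back to the embedding tables; their presence marks
// a layer call as a training call.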
constexpr absl::string_view kTPUGradientSendOps[] = {
    "SendTPUEmbeddingGradients", "XlaSendTPUEmbeddingGradients"};

}  // namespace

Status UpdateTPUEmbeddingEnqueueOrdinalPass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "UpdateTPUEmbeddingEnqueueOrdinalPass::Run";

  // Need the device set to get the number of devices per host.
  TF_RET_CHECK(options.device_set != nullptr);

  std::vector<Device*> tpu_devices;
  DeviceNameUtils::ParsedName tpu_device_spec;
  tpu_device_spec.has_type = true;
  tpu_device_spec.type = "TPU";
  options.device_set->FindMatchingDevices(tpu_device_spec, &tpu_devices);
  if (tpu_devices.empty()) {
    // If there are no TPUs, don't run this pass.
    return OkStatus();
  }

  TF_RET_CHECK(options.graph != nullptr);
  Graph* graph = options.graph->get();

  std::vector<Node*> embedding_nodes;
  for (Node* node : graph->op_nodes()) {
    if (absl::c_linear_search(kTPUEmbeddingOps, node->type_string())) {
      embedding_nodes.emplace_back(node);
    }
  }

  // Only run if there are embedding nodes.
  if (embedding_nodes.empty()) {
    return OkStatus();
  }

  DeviceNameUtils::ParsedName single_tpu_device_spec =
      tpu_devices[0]->parsed_name();

  TF_RET_CHECK(single_tpu_device_spec.has_job);

  // Note that TPUEmbedding is only supported on systems with a single TPU
  // slice (as determined by the 'job' portion of the device spec). Check for
  // that here just to be sure.
  for (const auto* tpu_device : tpu_devices) {
    TF_RET_CHECK(tpu_device->parsed_name().has_job);
    TF_RET_CHECK(tpu_device->parsed_name().job == single_tpu_device_spec.job)
        << "Multiple TPU jobs detected. This is not supported for now.";
  }

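  // Count the TPUs on a single task by matching on everything in the device
  // spec except the device id; every task is assumed to host the same number
  // of TPUs.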
  std::vector<Device*> task_devices;
  single_tpu_device_spec.has_id = false;
  options.device_set->FindMatchingDevices(single_tpu_device_spec,
                                          &task_devices);
  int64 num_tpus_per_task = task_devices.size();

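  // An enqueue op must target the TPU whose ordinal matches its replica's
  // position within its host, so fold the global replica id onto
  // [0, num_tpus_per_task). For example (hypothetically, with 4 TPUs per
  // task), replica 5 would get device_ordinal 5 % 4 == 1.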
  for (Node* node : embedding_nodes) {
    int64 replica_id;
    if (TryGetNodeAttr(node->attrs(), kXlaReplicaIdAttrName, &replica_id)) {
      node->AddAttr("device_ordinal", replica_id % num_tpus_per_task);
    }
  }

  VLOG(1) << "UpdateTPUEmbeddingEnqueueOrdinalPass::Run() finished";
  return OkStatus();
}

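// Records one TPUEmbedding layer call's ops, keyed by the call's
// "_tpu_embedding_layer" attribute: the enqueue op (identified by a Node* or
// a node_def index, depending on N) and whether a matching activation-receive
// or gradient-send op has been seen. Errors out on duplicate ops for a call,
// which indicates multiple TPUEmbedding layers.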
template <typename A, typename N>
Status UpdateMapsForModeOverride(
    const std::string& op, const A& attrs, const N node_identifier,
    std::map<std::string, N>* enqueue_op,
    std::map<std::string, bool>* found_recv_op,
    std::map<std::string, bool>* found_grad_send_op) {
  string layer_call_index;
  if (TryGetNodeAttr(attrs, "_tpu_embedding_layer", &layer_call_index)) {
    if ((op == kTPURecvOps[0]) || (op == kTPURecvOps[1])) {
      // We will prevent users from creating multiple copies of the
      // TPUEmbedding layer, so this should never happen.
      TF_RET_CHECK(!(*found_recv_op)[layer_call_index])
          << "Found second receive op for call " << layer_call_index << ". "
          << "This will happen if you create multiple TPUEmbedding layers. "
          << "Please ensure that you have only created one TPUEmbedding "
          << "layer.";
      (*found_recv_op)[layer_call_index] = true;
    } else if ((op == kTPUGradientSendOps[0]) ||
               (op == kTPUGradientSendOps[1])) {
      TF_RET_CHECK(!(*found_grad_send_op)[layer_call_index])
          << "Found second send op for call " << layer_call_index << ". "
          << "This will happen if you create multiple TPUEmbedding layers. "
          << "Please ensure that you have only created one TPUEmbedding "
          << "layer.";
      (*found_grad_send_op)[layer_call_index] = true;
    } else if (absl::c_linear_search(kTPUEmbeddingOps, op)) {
      TF_RET_CHECK(enqueue_op->find(layer_call_index) == enqueue_op->end())
          << "Found second enqueue op for call " << layer_call_index << ". "
          << "This will happen if you create multiple TPUEmbedding layers. "
          << "Please ensure that you have only created one TPUEmbedding "
          << "layer.";
      (*enqueue_op)[layer_call_index] = node_identifier;
    }
  }
  return OkStatus();
}

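// Computes, for each enqueue op recorded by UpdateMapsForModeOverride, whether
// its layer call is a training call: training if and only if a gradient send
// op was found for the same call. Every enqueue must have a matching receive.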
template <typename M, typename N>
Status ComputeEnqueueTrainingStatus(
    const std::map<std::string, N>& enqueue_op,
    const std::map<std::string, bool>& found_recv_op,
    const std::map<std::string, bool>& found_grad_send_op, M* enqueue) {
  TF_RET_CHECK(enqueue_op.size() == found_recv_op.size())
      << "Enqueue and recv ops should be in a one-to-one correspondence. "
      << "Found " << enqueue_op.size() << " enqueue(s) and "
      << found_recv_op.size() << " receive(s).";
  for (const auto& node : enqueue_op) {
    TF_RET_CHECK(found_recv_op.find(node.first) != found_recv_op.end())
        << "No receive for enqueue call " << node.first;
    bool send_exists =
        (found_grad_send_op.find(node.first) != found_grad_send_op.end());
    VLOG(1) << "Found call " << node.first
            << (send_exists ? " with" : " without") << " send op(s).";
    // If we found a gradient send op in the same cluster as the enqueue op,
    // then this is a training call, so set the output to true for this
    // enqueue.
    (*enqueue)[node.second] = send_exists;
  }
  return OkStatus();
}


// Get the enqueue ops and their status (training or eval) from a graph.
// `enqueue` is a map from the Graph Node* of an enqueue op to a bool that is
// true when the enqueue is part of a TPUEmbedding layer call that contains a
// gradient send op.
Status UpdateTPUEmbeddingModePass::GetEnqueueOpsFromGraph(
    Graph* graph, absl::flat_hash_map<Node*, bool>* enqueue) {
  // Maps are indexed by the TPUEmbedding layer's call number.
  std::map<std::string, Node*> enqueue_op;
  std::map<std::string, bool> found_recv_op;
  std::map<std::string, bool> found_grad_send_op;

  for (Node* node : graph->op_nodes()) {
    TF_RETURN_IF_ERROR(UpdateMapsForModeOverride(
        node->type_string(), node->attrs(), node, &enqueue_op, &found_recv_op,
        &found_grad_send_op));
    // Clear the attribute so any further executions of this pass don't
    // activate the pass again.
    node->ClearAttr("_tpu_embedding_layer");
  }

  return ComputeEnqueueTrainingStatus(enqueue_op, found_recv_op,
                                      found_grad_send_op, enqueue);
}


// Update the graph for a specific enqueue op.
Status UpdateTPUEmbeddingModePass::UpdateGraphEnqueueOp(bool training,
                                                        Graph* graph,
                                                        Node* enqueue) {
  // When using the layer, the mode override input is a SelectV2 op (unless
  // this pass has already run), which takes a training and an eval op as
  // input. We simply short-circuit the SelectV2 and take the input from the
  // correct op.
  const Edge* select_edge;
  TF_RETURN_IF_ERROR(
      enqueue->input_edge(enqueue->num_inputs() - 1, &select_edge));
  if (select_edge->src()->type_string() == "SelectV2") {
    TF_RET_CHECK(select_edge->src()->num_inputs() == 3);
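    // SelectV2's inputs are (condition, t, e); take `t` (input 1) for
    // training and `e` (input 2) for eval, bypassing the condition entirely.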
    Node* mode;
    TF_RETURN_IF_ERROR(select_edge->src()->input_node(training ? 1 : 2, &mode));
    graph->AddEdge(mode, 0, enqueue, enqueue->num_inputs() - 1);
    graph->RemoveEdge(select_edge);
  }

  return OkStatus();
}


// Get the enqueue ops and their status (training or eval) from a function def.
// The enqueue map is indexed by the position of the enqueue op in the
// function's node_def array.
Status UpdateTPUEmbeddingModePass::GetEnqueueOpsFromFunctionDef(
    FunctionDef* function, std::map<int, bool>* enqueue) {
  std::map<std::string, int> enqueue_op;
  std::map<std::string, bool> found_recv_op;
  std::map<std::string, bool> found_grad_send_op;

  for (int i = 0; i < function->node_def_size(); ++i) {
    const NodeDef& node = function->node_def(i);
    TF_RETURN_IF_ERROR(UpdateMapsForModeOverride(
        node.op(), node, i, &enqueue_op, &found_recv_op, &found_grad_send_op));
    // Clear the attribute so any further executions of this pass don't
    // activate the pass again.
    function->mutable_node_def(i)->mutable_attr()->erase(
        "_tpu_embedding_layer");
  }

  return ComputeEnqueueTrainingStatus(enqueue_op, found_recv_op,
                                      found_grad_send_op, enqueue);
}


// Update the function def for a specific enqueue op.
Status UpdateTPUEmbeddingModePass::UpdateFunctionDefEnqueueOp(
    int enqueue, bool training, FunctionDef* function, bool* updated) {
  // When using the layer, the mode override input is a SelectV2 op, which
  // takes a training and an eval op as input. We simply short-circuit the
  // SelectV2 and take the input from the correct op.
  NodeDef* node = function->mutable_node_def(enqueue);
  int mode_override = node->input_size() - 1;
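  // Walk backwards past control inputs (names beginning with '^') and empty
  // inputs: the last data input is the mode override.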
  while ((mode_override >= 0) && (node->input(mode_override).empty() ||
                                  (node->input(mode_override)[0] == '^'))) {
    mode_override--;
  }
  TF_RET_CHECK(mode_override >= 0) << "Can't find a non-control input to "
                                   << "the enqueue.";
  TF_RET_CHECK(!node->input(mode_override).empty());

  // Find the node that produces the mode override input. Inputs have the form
  // "name:output"; keep only the node name.
  string select_name = std::vector<std::string>(
      absl::StrSplit(node->input(mode_override), ':'))[0];
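  // A FunctionDef has no node-name index here, so do a linear scan for the
  // producer node.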
  int select = 0;
  while ((select < function->node_def_size()) &&
         (function->node_def(select).name() != select_name)) {
    select++;
  }
  TF_RET_CHECK(select < function->node_def_size())
      << "Unable to find enqueue input node " << select_name << " in function "
      << function->signature().name();
  if (function->node_def(select).op() == "SelectV2") {
    // Make the mode override input the same as the correct input of the
    // SelectV2.
    (*node->mutable_input(mode_override)) =
        function->node_def(select).input(training ? 1 : 2);
    *updated = true;
  }

  return OkStatus();
}


Status UpdateTPUEmbeddingModePass::Run(
    const GraphOptimizationPassOptions& options) {
  // Updates the enqueue ops when using a layer to set the mode override
  // behavior depending on the existence of gradient send ops.
  // Note we only do this when a layer is used (all BC ops carry an integer
  // attribute "_tpu_embedding_layer" that is incremented per call, so we can
  // easily associate the various ops with each other).
  //
  // Note that the BC ops can be in the Graph or in a FunctionDef.
  // If they are in the graph at stage 0, this means that there was no control
  // flow containing them (i.e. a host loop). In this case, we group together
  // ops with the same "_tpu_embedding_layer" tag.
  //
  // We also search all FunctionDefs. Note that as the ops are all created in
  // the layer's call, a cluster of TPUEmbedding ops won't be split across
  // different FunctionDefs.

  VLOG(1) << "UpdateTPUEmbeddingModePass::Run";

  TF_RET_CHECK(options.graph != nullptr);

  // First process the graph.
  Graph* graph = options.graph->get();
  absl::flat_hash_map<Node*, bool> enqueue_nodes;
  TF_RETURN_IF_ERROR(GetEnqueueOpsFromGraph(graph, &enqueue_nodes));
  for (const auto& enqueue : enqueue_nodes) {
    TF_RETURN_IF_ERROR(
        UpdateGraphEnqueueOp(enqueue.second, graph, enqueue.first));
  }

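  // Then process each FunctionDef in the library; enqueue ops created inside
  // control flow (e.g. a host training loop) live in FunctionDefs rather than
  // in the top-level graph.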
  for (const auto& fname : options.flib_def->ListFunctionNames()) {
    FunctionDef fdef_copy(*options.flib_def->Find(fname));
    std::map<int, bool> enqueue_nodes;
    TF_RETURN_IF_ERROR(
        GetEnqueueOpsFromFunctionDef(&fdef_copy, &enqueue_nodes));
    bool updated = false;
    for (const auto& enqueue : enqueue_nodes) {
      TF_RETURN_IF_ERROR(UpdateFunctionDefEnqueueOp(
          enqueue.first, enqueue.second, &fdef_copy, &updated));
    }

    if (updated) {
      TF_RETURN_IF_ERROR(options.flib_def->ReplaceFunction(fname, fdef_copy));
    }
  }

  VLOG(1) << "UpdateTPUEmbeddingModePass::Run() finished";
  return OkStatus();
}


}  // namespace tensorflow