1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Configuration for TPU Embedding.
17 
18 #include "tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.h"
19 
20 #include <string>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_map.h"
24 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/framework/node_def_util.h"
27 #include "tensorflow/core/graph/graph.h"
28 
29 namespace tensorflow {
30 namespace {
31 
// Op type names of the TPU embedding enqueue ops handled by the passes in this
// file.
constexpr absl::string_view kTPUEmbeddingOps[] = {
    "EnqueueTPUEmbeddingBatch",
    "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch",
    "EnqueueTPUEmbeddingArbitraryTensorBatch"};

// Op type names of the ops that receive TPU embedding activations.
constexpr absl::string_view kTPURecvOps[] = {"RecvTPUEmbeddingActivations",
                                             "XlaRecvTPUEmbeddingActivations"};

// Op type names of the ops that send TPU embedding gradients.
constexpr absl::string_view kTPUGradientSendOps[] = {
    "SendTPUEmbeddingGradients", "XlaSendTPUEmbeddingGradients"};
45 
46 }  // namespace
47 
Run(const GraphOptimizationPassOptions & options)48 Status UpdateTPUEmbeddingEnqueueOrdinalPass::Run(
49     const GraphOptimizationPassOptions& options) {
50   VLOG(1) << "UpdateTPUEmbeddingEnqueueOrdinalPass::Run";
51 
52   // Need the device set to get the number of devices per host.
53   TF_RET_CHECK(options.device_set != nullptr);
54 
55   std::vector<Device*> tpu_devices;
56   DeviceNameUtils::ParsedName tpu_device_spec;
57   tpu_device_spec.has_type = true;
58   tpu_device_spec.type = "TPU";
59   options.device_set->FindMatchingDevices(tpu_device_spec, &tpu_devices);
60   if (tpu_devices.empty()) {
61     // If there are no TPUs don't run this pass.
62     return OkStatus();
63   }
64 
65   TF_RET_CHECK(options.graph != nullptr);
66   Graph* graph = options.graph->get();
67 
68   std::vector<Node*> embedding_nodes;
69   for (Node* node : graph->op_nodes()) {
70     if (absl::c_linear_search(kTPUEmbeddingOps, node->type_string())) {
71       embedding_nodes.emplace_back(node);
72     }
73   }
74 
75   // Only run if there are embedding nodes.
76   if (embedding_nodes.empty()) {
77     return OkStatus();
78   }
79 
80   DeviceNameUtils::ParsedName single_tpu_device_spec =
81       tpu_devices[0]->parsed_name();
82 
83   TF_RET_CHECK(single_tpu_device_spec.has_job);
84 
85   // Note that TPUEmbedding is only supported on system with a single TPU slice
86   // (as determined by the 'job' portion of the device spec). Check for that
87   // here just to be sure.
88   for (const auto* tpu_device : tpu_devices) {
89     TF_RET_CHECK(tpu_device->parsed_name().has_job);
90     TF_RET_CHECK(tpu_device->parsed_name().job == single_tpu_device_spec.job)
91         << "Multiple TPU jobs detected. This is not supported for now.";
92   }
93 
94   std::vector<Device*> task_devices;
95   single_tpu_device_spec.has_id = false;
96   options.device_set->FindMatchingDevices(single_tpu_device_spec,
97                                           &task_devices);
98   int64 num_tpus_per_task = task_devices.size();
99 
100   for (Node* node : embedding_nodes) {
101     int64 replica_id;
102     if (TryGetNodeAttr(node->attrs(), kXlaReplicaIdAttrName, &replica_id)) {
103       node->AddAttr("device_ordinal", replica_id % num_tpus_per_task);
104     }
105   }
106 
107   VLOG(1) << "UpdateTPUEmbeddingEnqueueOrdinalPass::Run() finished";
108   return OkStatus();
109 }
110 
111 template <typename A, typename N>
UpdateMapsForModeOverride(const std::string & op,const A & attrs,const N node_identifier,std::map<std::string,N> * enqueue_op,std::map<std::string,bool> * found_recv_op,std::map<std::string,bool> * found_grad_send_op)112 Status UpdateMapsForModeOverride(
113     const std::string& op, const A& attrs, const N node_identifier,
114     std::map<std::string, N>* enqueue_op,
115     std::map<std::string, bool>* found_recv_op,
116     std::map<std::string, bool>* found_grad_send_op) {
117   string layer_call_index;
118   if (TryGetNodeAttr(attrs, "_tpu_embedding_layer", &layer_call_index)) {
119     if ((op == kTPURecvOps[0]) || (op == kTPURecvOps[1])) {
120       // We will prevent users from creating multiple copies of the
121       // TPUEmbedding layer so this should never happen.
122       TF_RET_CHECK(!(*found_recv_op)[layer_call_index])
123           << "Found second receive op for call " << layer_call_index << ". "
124           << "This will happen if you create multiple TPUEmbedding layers. "
125           << "Please ensure that you have only created one TPUEmbedding "
126           << "layer.";
127       (*found_recv_op)[layer_call_index] = true;
128     } else if ((op == kTPUGradientSendOps[0]) ||
129                (op == kTPUGradientSendOps[1])) {
130       TF_RET_CHECK(!(*found_grad_send_op)[layer_call_index])
131           << "Found second send op for call " << layer_call_index << ". "
132           << "This will happen if you create multiple TPUEmbedding layers. "
133           << "Please ensure that you have only created one TPUEmbedding "
134           << "layer.";
135       (*found_grad_send_op)[layer_call_index] = true;
136     } else if (absl::c_linear_search(kTPUEmbeddingOps, op)) {
137       TF_RET_CHECK(enqueue_op->find(layer_call_index) == enqueue_op->end())
138           << "Found second enqueue op for call " << layer_call_index << ". "
139           << "This will happen if you create multiple TPUEmbedding layers. "
140           << "Please ensure that you have only created one TPUEmbedding "
141           << "layer.";
142       (*enqueue_op)[layer_call_index] = node_identifier;
143     }
144   }
145   return OkStatus();
146 }
147 
148 template <typename M, typename N>
ComputeEnqueueTrainingStatus(const std::map<std::string,N> & enqueue_op,const std::map<std::string,bool> & found_recv_op,const std::map<std::string,bool> & found_grad_send_op,M * enqueue)149 Status ComputeEnqueueTrainingStatus(
150     const std::map<std::string, N>& enqueue_op,
151     const std::map<std::string, bool>& found_recv_op,
152     const std::map<std::string, bool>& found_grad_send_op, M* enqueue) {
153   TF_RET_CHECK(enqueue_op.size() == found_recv_op.size())
154       << "Enqueue and recv ops should be in a one-to-one corresondence."
155       << "Found " << enqueue_op.size() << " enqueue(s) and "
156       << found_recv_op.size() << " receive(s).";
157   for (const auto& node : enqueue_op) {
158     TF_RET_CHECK(found_recv_op.find(node.first) != found_recv_op.end())
159         << "No receive for enqueue call " << node.first;
160     bool send_exists =
161         (found_grad_send_op.find(node.first) != found_grad_send_op.end());
162     VLOG(1) << "Found call " << node.first
163         << (send_exists ? " with " : " without ") << " send op(s).";
164     // If we have found a send gradient op for that is in the same cluster as
165     // the enqueue op, then this is a training call so set the output to true
166     // for this
167     (*enqueue)[node.second] = send_exists;
168   }
169   return OkStatus();
170 }
171 
172 // Get the enqueue ops and their status (training or eval) from a graph.
173 // enqueue is a map from a Graph Node* for an enqueue op to a bool which is true
174 // when the enqueue is part of a TPUEmbedding layer call that contains a send
175 // gradients.
GetEnqueueOpsFromGraph(Graph * graph,absl::flat_hash_map<Node *,bool> * enqueue)176 Status UpdateTPUEmbeddingModePass::GetEnqueueOpsFromGraph(
177     Graph* graph, absl::flat_hash_map<Node*, bool>* enqueue) {
178   // Maps are index by the TPUEmbedding layer's call number.
179   std::map<std::string, Node*> enqueue_op;
180   std::map<std::string, bool> found_recv_op;
181   std::map<std::string, bool> found_grad_send_op;
182 
183   for (Node* node : graph->op_nodes()) {
184     TF_RETURN_IF_ERROR(UpdateMapsForModeOverride(
185         node->type_string(), node->attrs(), node, &enqueue_op, &found_recv_op,
186         &found_grad_send_op));
187     // Clear attribute so any further executions of this pass don't activate
188     // pass.
189     node->ClearAttr("_tpu_embedding_layer");
190   }
191 
192   return ComputeEnqueueTrainingStatus(enqueue_op, found_recv_op,
193                                       found_grad_send_op, enqueue);
194 }
195 
196 // Update the graph for a specific enqueue op.
UpdateGraphEnqueueOp(bool training,Graph * graph,Node * enqueue)197 Status UpdateTPUEmbeddingModePass::UpdateGraphEnqueueOp(bool training,
198                                                      Graph* graph,
199                                                      Node* enqueue) {
200   // When using the layer, the mode override input is a SelectV2 op (unless this
201   // pass has already run), which takes a training and eval op as input. We will
202   // simply short circut the SelectV2 and take input from the correct op.
203   const Edge* select_edge;
204   TF_RETURN_IF_ERROR(
205       enqueue->input_edge(enqueue->num_inputs() - 1, &select_edge));
206   if (select_edge->src()->type_string() == "SelectV2") {
207     TF_RET_CHECK(select_edge->src()->num_inputs() == 3);
208     Node* mode;
209     TF_RETURN_IF_ERROR(select_edge->src()->input_node(training ? 1 : 2, &mode));
210     graph->AddEdge(mode, 0, enqueue, enqueue->num_inputs() - 1);
211     graph->RemoveEdge(select_edge);
212   }
213 
214   return OkStatus();
215 }
216 
217 // Get the enqueue ops and their status (training or eval) from a function def.
218 // The enqueue map is indexed by the position of the enqueue op in the
219 // function's node_def array.
GetEnqueueOpsFromFunctionDef(FunctionDef * function,std::map<int,bool> * enqueue)220 Status UpdateTPUEmbeddingModePass::GetEnqueueOpsFromFunctionDef(
221     FunctionDef* function, std::map<int, bool>* enqueue) {
222   std::map<std::string, int> enqueue_op;
223   std::map<std::string, bool> found_recv_op;
224   std::map<std::string, bool> found_grad_send_op;
225 
226   std::string cluster;
227   for (int i = 0; i < function->node_def_size(); ++i) {
228     const NodeDef& node = function->node_def(i);
229     TF_RETURN_IF_ERROR(UpdateMapsForModeOverride(
230         node.op(), node, i, &enqueue_op, &found_recv_op, &found_grad_send_op));
231     // Clear attribute so any further executions of this pass don't activate
232     // pass.
233     function->mutable_node_def(i)->mutable_attr()->erase(
234         "_tpu_embedding_layer");
235   }
236 
237   return ComputeEnqueueTrainingStatus(enqueue_op, found_recv_op,
238                                       found_grad_send_op, enqueue);
239 }
240 
241 // Update the function def for a specific enqueue op.
UpdateFunctionDefEnqueueOp(int enqueue,bool training,FunctionDef * function,bool * updated)242 Status UpdateTPUEmbeddingModePass::UpdateFunctionDefEnqueueOp(
243     int enqueue, bool training, FunctionDef* function, bool* updated) {
244   // When using the layer, the mode override input is a SelectV2 op,
245   // which takes a training and eval op as input. We will simply short circut
246   // the SelectV2 and take input from the correct op.
247   NodeDef* node = function->mutable_node_def(enqueue);
248   int mode_override = node->input_size() - 1;
249   while ((mode_override >= 0) && (node->input(mode_override).empty() ||
250                                   (node->input(mode_override)[0] == '^'))) {
251     mode_override--;
252   }
253   TF_RET_CHECK(mode_override >= 0) << "Can't find non-control input to "
254                                    << "enqueue.";
255   TF_RET_CHECK(!node->input(mode_override).empty());
256 
257   // Find input node
258   string select_name = std::vector<std::string>(
259       absl::StrSplit(node->input(mode_override), ':'))[0];
260   int select = 0;
261   while ((select < function->node_def_size()) &&
262          (function->node_def(select).name() != select_name)) {
263     select++;
264   }
265   TF_RET_CHECK(select < function->node_def_size())
266       << "Unable to find enqueue input node " << select_name << " in function "
267       << function->signature().name();
268   if (function->node_def(select).op() == "SelectV2") {
269     // Make the mode override input the same as the correct input of the
270     // select v2.
271     (*node->mutable_input(mode_override)) =
272         function->node_def(select).input(training ? 1 : 2);
273     *updated = true;
274   }
275 
276   return OkStatus();
277 }
278 
Run(const GraphOptimizationPassOptions & options)279 Status UpdateTPUEmbeddingModePass::Run(
280     const GraphOptimizationPassOptions& options) {
281   // Updates the Enqueue ops when using a layer to set the mode override
282   // behavior depending on the existence of send gradients ops.
283   // Note we only do this when a layer is used (all BC ops with an integer
284   // attribute "_tpu_embedding_layer" that is incremented per call, so we can
285   // easily associate the various ops).
286   //
287   // Note that the BC ops can be in the Graph or in the FunctionDef.
288   // If they are in the graph at stage 0, this means that there as no control
289   // flow containing them (i.e. a host loop). In this case, we group together
290   // ops with the same "_tpu_embedding_layer" tag.
291   //
292   // We also search all FunctionDefs. Note that as the ops are all created in
293   // the layer's call, a cluster of TPUEmbedding ops won't be split across
294   // different FunctionDefs.
295 
296   VLOG(1) << "UpdateTPUEmbeddingModePass::Run";
297 
298   TF_RET_CHECK(options.graph != nullptr);
299 
300   // First process the graph
301   Graph* graph = options.graph->get();
302   absl::flat_hash_map<Node*, bool> enqueue_nodes;
303   TF_RETURN_IF_ERROR(GetEnqueueOpsFromGraph(graph, &enqueue_nodes));
304   for (const auto& enqueue : enqueue_nodes) {
305     TF_RETURN_IF_ERROR(
306         UpdateGraphEnqueueOp(enqueue.second, graph, enqueue.first));
307   }
308 
309   for (const auto& fname : options.flib_def->ListFunctionNames()) {
310     FunctionDef fdef_copy(*options.flib_def->Find(fname));
311     std::map<int, bool> enqueue_nodes;
312     TF_RETURN_IF_ERROR(
313         GetEnqueueOpsFromFunctionDef(&fdef_copy, &enqueue_nodes));
314     bool updated = false;
315     for (const auto& enqueue : enqueue_nodes) {
316       TF_RETURN_IF_ERROR(UpdateFunctionDefEnqueueOp(
317           enqueue.first, enqueue.second, &fdef_copy, &updated));
318     }
319 
320     if (updated) {
321       TF_RETURN_IF_ERROR(options.flib_def->ReplaceFunction(fname, fdef_copy));
322     }
323   }
324 
325   VLOG(1) << "UpdateTPUEmbeddingModePass::Run() finished";
326   return OkStatus();
327 }
328 
329 }  // namespace tensorflow
330