1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string>
16 #include <utility>
17 
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/collective.h"
22 #include "tensorflow/core/framework/device_attributes.pb.h"
23 #include "tensorflow/core/framework/node_def.pb.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/op_requires.h"
26 #include "tensorflow/core/framework/resource_handle.h"
27 #include "tensorflow/core/framework/resource_mgr.h"
28 #include "tensorflow/core/framework/tensor_util.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/refcount.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace tensorflow {
38 
39 namespace {
40 
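// Builds the string key that pairs up participants of the same collective
// instance across devices: "<group_key>:<instance_key>:<frame_id>:<iter_id>".
// As an illustrative example (assuming the outermost frame, where frame_id and
// iter_id are both 0), group_key 1 with instance_key 7 yields "1:7:0:0".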
41 static string CollectiveKey(OpKernelContext* ctx, int32_t group_key,
42                             int32_t instance_key) {
43   return strings::StrCat(group_key, ":", instance_key, ":",
44                          ctx->frame_iter().frame_id, ":",
45                          ctx->frame_iter().iter_id);
46 }
47 
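// Builds the sub-kernel (e.g. "Maximum" or "Div") that a reduction uses as its
// merge_op or final_op. Returns a null pointer when no sub-kernel is needed,
// i.e. when `name` is empty or "Id"; any failure to create the kernel is
// reported on the OpKernelConstruction context as an Internal error.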
48 static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
49                                                const string& name,
50                                                NodeDef* sub_node) {
51   std::unique_ptr<OpKernel> k;
52   if (name.empty() || name == "Id") return k;
53   sub_node->set_name(name);
54   sub_node->set_op(name);
55   Status status;
56   k = CreateOpKernel(c->device_type(), c->device(),
57                      c->device()->GetAllocator(AllocatorAttributes()),
58                      *sub_node, c->graph_def_version(), &status);
59   if (!status.ok()) {
60     c->CtxFailureWithWarning(errors::Internal(
61         "Failed to build OpKernel for ", name, " : ", status.error_message()));
62   }
63   return k;
64 }
65 
66 class CollectiveOpV1Kernel : public AsyncOpKernel {
67  public:
68   explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
69       : AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}
70 
71   ~CollectiveOpV1Kernel() override { col_params_->Unref(); }
72 
73   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
74     CollectiveExecutor* col_exec = c->collective_executor();
75     OP_REQUIRES_ASYNC(
76         c, col_exec,
77         errors::Internal(
78             "Failed to get CollectiveExecutor from OpKernelContext for Op ",
79             name_),
80         done);
81     const CancellationToken token =
82         c->cancellation_manager()->get_cancellation_token();
83     const bool already_cancelled =
84         !c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
85           // We must call StartAbort() within the callback. StartAbort() relies
86           // on resources that may be deallocated once all execution of the
87           // graph has finished.
88           col_exec->StartAbort(errors::Cancelled("op cancelled"));
89         });
90     OP_REQUIRES_ASYNC(c, !already_cancelled,
91                       errors::Cancelled("op cancelled ", name_), done);
92 
93     auto deregister_and_done = [c, token, done = std::move(done)]() {
94       // Once done() is called, StartAbort() won't have any effect, so we
95       // don't need to block on the deregistration. Also StartAbort() may call
96       // done() and DeregisterCallback may deadlock.
97       c->cancellation_manager()->TryDeregisterCallback(token);
98       done();
99     };
100     ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
101   }
102 
103   // A string encoding instance, frame and iter to be handed off to
104   // the implementation for use in generating RecvBuf keys.
105   string GetCollectiveKey(OpKernelContext* c) {
106     return CollectiveKey(c, col_params_->group.group_key,
107                          col_params_->instance.instance_key);
108   }
109 
110   // Returns false if the calling invocation of ComputeAsync should return
111   // immediately.
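  // On the first invocation the group membership is not yet resolved, so this
  // schedules CompleteParamsAsync on the collective executor's closure queue
  // and re-enters ComputeAsync from its callback once the parameters are known.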
112   bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
113                              const DoneCallback& done) {
114     if (col_params_->group.group_size > col_params_->group.members.size()) {
115       // This is the first invocation: Finish initializing col_params_.
116       // Schedule the `CompleteParamsAsync` call on a work queue that can handle
117       // blocking work, because the call itself may block.
118       c->collective_executor()->RunClosure([this, c, col_exec, done]() {
119         VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
120                 << col_params_->name << " device " << c->device()->name()
121                 << " group " << col_params_->group.group_key << " instance "
122                 << col_params_->instance.instance_key;
123         col_exec->CompleteParamsAsync(
124             c->device()->attributes(), col_params_, c->cancellation_manager(),
125             [this, c, done](const Status& s) {
126               if (s.ok()) {
127                 col_params_->instance.impl_details.dependencies = dependencies_;
128                 ComputeAsync(c, done);
129               } else {
130                 c->SetStatus(s);
131                 done();
132               }
133             });
134       });
135       return false;
136     }
137     return true;
138   }
139 
140  protected:
141   virtual void ComputeAsyncImpl(OpKernelContext* c,
142                                 CollectiveExecutor* col_exec,
143                                 DoneCallback done) = 0;
144 
145   string name_;
146   CollectiveParams* col_params_;
147   std::vector<int32> dependencies_;
148 };
149 
150 class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
151  public:
152   explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
153       : CollectiveOpV1Kernel(c) {
154     col_params_->instance.type = GATHER_COLLECTIVE;
155     OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
156     OP_REQUIRES(
157         c, col_params_->group.group_size > 0,
158         errors::InvalidArgument("group_size must be positive integer but got ",
159                                 col_params_->group.group_size));
160     OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
161     OP_REQUIRES_OK(
162         c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
163     OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
164     OP_REQUIRES_OK(
165         c, c->GetAttr("communication_hint",
166                       &col_params_->instance.impl_details.communication_hint));
167     OP_REQUIRES_OK(
168         c, c->GetAttr("timeout_seconds",
169                       &col_params_->instance.impl_details.timeout_seconds));
170     const NodeDef& real_node = c->def();
171     col_params_->name = strings::StrCat(real_node.name(), ": Gather");
172     col_params_->group.device_type = c->device_type();
173   }
174 
175  protected:
176   void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
177                         DoneCallback done) override {
178     auto output_shape = c->input(0).shape();
179     OP_REQUIRES_ASYNC(c, output_shape.dims() > 0,
180                       errors::InvalidArgument("input should have rank > 0, ",
181                                               "received ", output_shape.dims()),
182                       done);
183     output_shape.set_dim(
184         0, output_shape.dim_size(0) * col_params_->group.group_size);
185     col_params_->instance.shape = output_shape;
186 
187     // Allocate output on the first pass through this function.  This must be
188     // done immediately, while we're still in the executor thread.  Otherwise
189     // the memory is not guaranteed to be unused by any concurrently executing
190     // GPU kernel.
191     if (c->mutable_output(0) == nullptr) {
192       // Allocate the output tensor.
193       Tensor* output = nullptr;
194       OP_REQUIRES_OK_ASYNC(
195           c, c->allocate_output(0, col_params_->instance.shape, &output), done);
196     }
197     if (!CanProceedWithCompute(c, col_exec, done)) return;
198 
199     auto actual_done = [c, col_params = col_params_, done](const Status& s) {
200       VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
201               << c->op_kernel().name() << " device " << c->device()->name()
202               << " group " << col_params->group.group_key << " instance "
203               << col_params->instance.instance_key << " status " << s;
204       col_params->Unref();
205       OP_REQUIRES_OK_ASYNC(c, s, done);
206       done();
207     };
208     VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
209             << col_params_->name << " device " << c->device()->name()
210             << " group " << col_params_->group.group_key << " instance "
211             << col_params_->instance.instance_key;
212     col_params_->Ref();
213     col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
214   }
215 
216  private:
217   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveGatherOpKernel);
218 };
219 
220 REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
221                         CollectiveGatherOpKernel);
222 REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
223                         CollectiveGatherOpKernel);
224 
225 class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
226  public:
227   explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
228       : CollectiveOpV1Kernel(c) {
229     col_params_->instance.type = REDUCTION_COLLECTIVE;
230     OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
231     OP_REQUIRES(
232         c, col_params_->group.group_size > 0,
233         errors::InvalidArgument("group_size must be positive integer but got ",
234                                 col_params_->group.group_size));
235     OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
236     OP_REQUIRES_OK(
237         c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
238     OP_REQUIRES_OK(
239         c, c->GetAttr("subdiv_offsets",
240                       &col_params_->instance.impl_details.subdiv_offsets));
241     string merge_op_name;
242     OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
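    // The op attribute uses the short names "Max"/"Min", while the elementwise
    // kernels used as the merge_op are registered as "Maximum"/"Minimum", so
    // remap the names before building the sub-kernel.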
243     if (merge_op_name == "Max") {
244       merge_op_name = "Maximum";
245     } else if (merge_op_name == "Min") {
246       merge_op_name = "Minimum";
247     }
248     string final_op_name;
249     OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
250     OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
251                 errors::InvalidArgument(
252                     "final_op must be one of {\"Id\", \"Div\"} but got ",
253                     final_op_name));
254     OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
255     OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
256     OP_REQUIRES_OK(
257         c, c->GetAttr("communication_hint",
258                       &col_params_->instance.impl_details.communication_hint));
259     OP_REQUIRES_OK(
260         c, c->GetAttr("timeout_seconds",
261                       &col_params_->instance.impl_details.timeout_seconds));
262     VLOG(2) << "CollectiveReduce instance "
263             << col_params_->instance.instance_key << " merge_op "
264             << merge_op_name << " final_op " << final_op_name
265             << " communication_hint "
266             << col_params_->instance.impl_details.communication_hint
267             << " timeout "
268             << col_params_->instance.impl_details.timeout_seconds;
269 
270     const NodeDef& real_node = c->def();
271     col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
272                                         merge_op_name, ",", final_op_name, ")");
273     col_params_->group.device_type = c->device_type();
274 
275     // Find the OpKernels by name, type and device type.
276     NodeDef sub_node;
277     // The merge_op takes two inputs
278     sub_node.add_input(real_node.input(0));
279     sub_node.add_input(real_node.input(0));
280     sub_node.set_device(real_node.device());
281     SetAttrValue(col_params_->instance.data_type,
282                  &(*sub_node.mutable_attr())["T"]);
283     merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
284     final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
285     col_params_->merge_op = merge_op_.get();
286     col_params_->final_op = final_op_.get();
287   }
288 
289  protected:
290   void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
291                         DoneCallback done) override {
292     // Allocate output on the first pass through this function.  This must be
293     // done immediately, while we're still in the executor thread.  Otherwise
294     // the memory is not guaranteed to be unused by any concurrently executing
295     // GPU kernel.
296     if (c->mutable_output(0) == nullptr) {
297       // Allocate the output tensor, trying to reuse the input.
298       Tensor* output = nullptr;
299       OP_REQUIRES_OK_ASYNC(c,
300                            c->forward_input_or_allocate_output(
301                                {0}, 0, c->input(0).shape(), &output),
302                            done);
303       col_params_->instance.shape = c->input(0).shape();
304     }
305     if (!CanProceedWithCompute(c, col_exec, done)) return;
306 
307     auto actual_done = [c, col_params = col_params_, done](const Status& s) {
308       VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
309               << c->op_kernel().name() << " device " << c->device()->name()
310               << " group " << col_params->group.group_key << " instance "
311               << col_params->instance.instance_key << " status " << s;
312       col_params->Unref();
313       OP_REQUIRES_OK_ASYNC(c, s, done);
314       done();
315     };
316     VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
317             << col_params_->name << " device " << c->device()->name()
318             << " group " << col_params_->group.group_key << " instance "
319             << col_params_->instance.instance_key;
320     col_params_->Ref();
321     col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
322   }
323 
324  private:
325   std::unique_ptr<OpKernel> merge_op_;
326   std::unique_ptr<OpKernel> final_op_;
327   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
328 };
329 
330 REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
331                         CollectiveReduceOpKernel);
332 REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
333                         CollectiveReduceOpKernel);
334 
335 class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
336  public:
337   explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
338       : CollectiveOpV1Kernel(c) {
339     col_params_->instance.type = BROADCAST_COLLECTIVE;
340     OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
341     OP_REQUIRES(
342         c, col_params_->group.group_size > 0,
343         errors::InvalidArgument("group_size must be positive integer but got ",
344                                 col_params_->group.group_size));
345     OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
346     OP_REQUIRES_OK(
347         c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
348     OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
349     OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
350     OP_REQUIRES_OK(
351         c, c->GetAttr("communication_hint",
352                       &col_params_->instance.impl_details.communication_hint));
353     OP_REQUIRES_OK(
354         c, c->GetAttr("timeout_seconds",
355                       &col_params_->instance.impl_details.timeout_seconds));
356     col_params_->is_source = true;
357     col_params_->instance.impl_details.subdiv_offsets = {0};
358 
359     col_params_->name =
360         strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
361     col_params_->group.device_type = c->device_type();
362   }
363 
364  protected:
365   void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
366                         DoneCallback done) override {
367     // Allocate output on the first pass through this function.  This must be
368     // done immediately, while we're still in the executor thread.  Otherwise
369     // the memory is not guaranteed to be unused by any concurrently executing
370     // GPU kernel.
371     if (c->mutable_output(0) == nullptr) {
372       // Allocate the output tensor, trying to reuse the input.
373       Tensor* output = nullptr;
374       OP_REQUIRES_OK_ASYNC(c,
375                            c->forward_input_or_allocate_output(
376                                {0}, 0, col_params_->instance.shape, &output),
377                            done);
378     }
379     if (!CanProceedWithCompute(c, col_exec, done)) return;
380     OP_REQUIRES_ASYNC(
381         c, col_params_->instance.shape.IsSameSize(c->input(0).shape()),
382         errors::Internal("Declared shape of op ", col_params_->name,
383                          " does not match shape of input"),
384         done);
385 
386     auto actual_done = [c, col_params = col_params_, done](const Status& s) {
387       VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
388               << c->op_kernel().name() << " device " << c->device()->name()
389               << " group " << col_params->group.group_key << " instance "
390               << col_params->instance.instance_key << " status " << s;
391       col_params->Unref();
392       OP_REQUIRES_OK_ASYNC(c, s, done);
393       done();
394     };
395     VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
396             << col_params_->name << " device " << c->device()->name()
397             << " group " << col_params_->group.group_key << " instance "
398             << col_params_->instance.instance_key;
399     col_params_->Ref();
400     col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
401   }
402 
403  private:
404   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
405 };
406 
407 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
408                         CollectiveBcastSendOpKernel);
409 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_DEFAULT),
410                         CollectiveBcastSendOpKernel);
411 
412 class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
413  public:
414   explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
415       : CollectiveOpV1Kernel(c) {
416     col_params_->instance.type = BROADCAST_COLLECTIVE;
417     OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
418     OP_REQUIRES(
419         c, col_params_->group.group_size > 0,
420         errors::InvalidArgument("group_size must be positive integer but got ",
421                                 col_params_->group.group_size));
422     OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
423     OP_REQUIRES_OK(
424         c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
425     OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
426     OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_->instance.shape));
427     OP_REQUIRES_OK(
428         c, c->GetAttr("communication_hint",
429                       &col_params_->instance.impl_details.communication_hint));
430     OP_REQUIRES_OK(
431         c, c->GetAttr("timeout_seconds",
432                       &col_params_->instance.impl_details.timeout_seconds));
433     col_params_->is_source = false;
434     col_params_->instance.impl_details.subdiv_offsets = {0};
435 
436     col_params_->name =
437         strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
438     col_params_->group.device_type = c->device_type();
439   }
440 
441  protected:
442   void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
443                         DoneCallback done) override {
444     // Allocate output on the first pass through this function.  This must be
445     // done immediately, while we're still in the executor thread.  Otherwise
446     // the memory is not guaranteed to be unused by any concurrently executing
447     // GPU kernel.
448     if (c->mutable_output(0) == nullptr) {
449       // No input, so must allocate output.
450       Tensor* output = nullptr;
451       OP_REQUIRES_OK_ASYNC(
452           c, c->allocate_output(0, col_params_->instance.shape, &output), done);
453     }
454     if (!CanProceedWithCompute(c, col_exec, done)) return;
455 
456     auto actual_done = [c, col_params = col_params_, done](const Status& s) {
457       VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
458               << c->op_kernel().name() << " device " << c->device()->name()
459               << " group " << col_params->group.group_key << " instance_key "
460               << col_params->instance.instance_key << " status " << s;
461       col_params->Unref();
462       OP_REQUIRES_OK_ASYNC(c, s, done);
463       done();
464     };
465     VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
466             << col_params_->name << " device " << c->device()->name()
467             << " group " << col_params_->group.group_key << " instance "
468             << col_params_->instance.instance_key;
469     col_params_->Ref();
470     col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
471   }
472 
473  private:
474   TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
475 };
476 
477 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
478                         CollectiveBcastRecvOpKernel);
479 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_DEFAULT),
480                         CollectiveBcastRecvOpKernel);
481 
482 class CollectiveAssignGroupV2OpKernel : public OpKernel {
483  public:
484   explicit CollectiveAssignGroupV2OpKernel(OpKernelConstruction* c)
485       : OpKernel(c) {}
486 
487   void Compute(OpKernelContext* context) override {
488     const Tensor& group_assignment = context->input(0);
489     const Tensor& device_index = context->input(1);
490     const Tensor& base_key = context->input(2);
491 
492     OP_REQUIRES(
493         context, TensorShapeUtils::IsScalar(device_index.shape()),
494         errors::InvalidArgument(
495             "device_index must be a scalar, but received tensor of shape: ",
496             device_index.shape().DebugString()));
497 
498     OP_REQUIRES(
499         context, TensorShapeUtils::IsMatrix(group_assignment.shape()),
500         errors::InvalidArgument("group_assignment must be a 2-d Tensor, but "
501                                 "received tensor of shape: ",
502                                 group_assignment.shape().DebugString()));
503     OP_REQUIRES(context, TensorShapeUtils::IsScalar(base_key.shape()),
504                 errors::InvalidArgument(
505                     "base_key must be a scalar, but received tensor of shape: ",
506                     base_key.shape().DebugString()));
507 
508     Tensor* group_key = nullptr;
509     Tensor* group_size = nullptr;
510     AllocatorAttributes attr;
511     attr.set_on_host(true);
512     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
513                                                      &group_size, attr));
514 
515     OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
516                                                      &group_key, attr));
517 
518     OP_REQUIRES_OK(
519         context,
520         ComputeGroupKey(group_assignment, device_index.scalar<int32_t>()(),
521                         base_key.scalar<int32_t>()(), group_size, group_key));
522   }
523 
524  private:
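  // Derives this device's group_size and group_key from the dense group
  // assignment. As a purely illustrative example: with group_assignment =
  // [[0, 1], [2, 3]], device_index = 2 and base_key = 10, the device is found
  // in row (group_id) 1, so group_size is 2 and group_key is
  // base_key + group_id = 11.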
525   static Status ComputeGroupKey(const Tensor& group_assignment,
526                                 const int32_t device_index,
527                                 const int32_t base_key, Tensor* group_size,
528                                 Tensor* group_key) {
529     group_size->flat<int32_t>()(0) = group_assignment.dim_size(1);
530 
531     for (int group_id = 0; group_id < group_assignment.dim_size(0);
532          group_id++) {
533       int32_t key = static_cast<int32_t>(static_cast<uint32_t>(base_key) +
534                                          static_cast<uint32_t>(group_id));
535       if (key == 0) {
536         return errors::InvalidArgument(
537             "Using the reserved group_key = 0 is not allowed: group_id = ",
538             group_id, ", base_key = ", base_key);
539       }
540       for (int color = 0; color < group_assignment.dim_size(1); color++) {
541         const auto index = group_assignment.matrix<int32>()(group_id, color);
542         if (index < 0 || index >= group_assignment.shape().num_elements()) {
543           return errors::InvalidArgument("Not all items in group_assignment ",
544                                          group_assignment.DebugString(),
545                                          " are within [0, number of devices)");
546         }
547         if (index == device_index) {
548           group_key->flat<int32_t>()(0) = key;
549           VLOG(2) << " group_assignment = " << group_assignment.DebugString()
550                   << " device_index = " << index
551                   << " group_key = " << group_key->DebugString()
552                   << " group_size = " << group_size->DebugString();
553           return OkStatus();
554         }
555       }
556     }
557     return errors::InvalidArgument("device_index ", device_index,
558                                    " is not found in group_assignment ",
559                                    group_assignment.DebugString());
560   }
561 };
562 
563 REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2").Device(DEVICE_CPU),
564                         CollectiveAssignGroupV2OpKernel);
565 REGISTER_KERNEL_BUILDER(Name("CollectiveAssignGroupV2")
566                             .Device(DEVICE_DEFAULT)
567                             .HostMemory("device_index")
568                             .HostMemory("group_assignment")
569                             .HostMemory("base_key")
570                             .HostMemory("group_size")
571                             .HostMemory("group_key"),
572                         CollectiveAssignGroupV2OpKernel);
573 
574 class CollectiveOpV2Kernel : public AsyncOpKernel {
575  public:
576   explicit CollectiveOpV2Kernel(OpKernelConstruction* c)
577       : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
578     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
579     OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
580     OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
581     device_type_ = c->device_type();
582   }
583 
584  protected:
585   // Fills common parts of CollectiveParams according to the Op, *excluding
586   // output_shape*. Kernels should further work on the CollectiveParams if they
587   // need to set additional fields.
588   Status FillCollectiveParams(CollectiveParams* col_params,
589                               CollectiveType collective_type,
590                               const Tensor& group_size, const Tensor& group_key,
591                               const Tensor& instance_key) {
592     if (group_size.dims() > 0) {
593       return errors::InvalidArgument(
594           "Unexpected dimensions on input group_size, got ",
595           group_size.shape().DebugString());
596     }
597     if (group_key.dims() > 0) {
598       return errors::InvalidArgument(
599           "Unexpected dimensions on input group_key, got ",
600           group_key.shape().DebugString());
601     }
602     if (instance_key.dims() > 0) {
603       return errors::InvalidArgument(
604           "Unexpected dimensions on input instance_key, got ",
605           instance_key.shape().DebugString());
606     }
607     col_params->name = name_;
608     col_params->group.device_type = device_type_;
609     col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
610     if (col_params->group.group_size <= 0) {
611       return errors::InvalidArgument(
612           "group_size must be positive integer but got ",
613           col_params->group.group_size);
614     }
615     col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
616     col_params->instance.type = collective_type;
617     col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
618     col_params->instance.data_type = data_type_;
619     col_params->instance.impl_details.communication_hint = communication_hint_;
620     col_params->instance.impl_details.timeout_seconds = timeout_seconds_;
621     return OkStatus();
622   }
623 
624   // Runs a collective. The output tensor must be allocated before calling this
625   // method. col_params must live until done is called.
626   void Run(OpKernelContext* c, CollectiveParams* col_params,
627            DoneCallback done) {
628     CollectiveExecutor* col_exec = c->collective_executor();
629     OP_REQUIRES_ASYNC(
630         c, col_exec,
631         errors::Internal(
632             "Failed to get CollectiveExecutor from OpKernelContext for Op ",
633             name_),
634         done);
635     // Resolve the collective params.
636     // Schedule the `CompleteParamsAsync` call on a work queue that can handle
637     // blocking work, because the call itself may block.
638     c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
639                                           col_exec]() {
640       VLOG(1) << "Collective CompleteParams for " << col_params->name
641               << " device " << c->device()->name() << " group "
642               << col_params->group.group_key << " instance "
643               << col_params->instance.instance_key;
644       col_exec->CompleteParamsAsync(
645           c->device()->attributes(), col_params, c->cancellation_manager(),
646           [c, done = std::move(done), col_params, col_exec](const Status& s) {
647             if (s.ok()) {
648               auto actual_done = [c, col_params,
649                                   done = std::move(done)](const Status& s) {
650                 VLOG(1) << "Collective ExecuteAsync done for "
651                         << col_params->name << " device " << c->device()->name()
652                         << " group " << col_params->group.group_key
653                         << " instance " << col_params->instance.instance_key
654                         << " status " << s;
655                 if (!s.ok()) {
656                   c->SetStatus(s);
657                 }
658                 done();
659               };
660               VLOG(1) << "Collective ExecuteAsync start for "
661                       << col_params->name << " device " << c->device()->name()
662                       << " group " << col_params->group.group_key
663                       << " instance " << col_params->instance.instance_key;
664               col_exec->ExecuteAsync(
665                   c, col_params,
666                   CollectiveKey(c, col_params->group.group_key,
667                                 col_params->instance.instance_key),
668                   actual_done);
669             } else {
670               c->SetStatus(s);
671               done();
672             }
673           });
674     });
675   }
676 
677  protected:
678   string name_;
679   DataType data_type_ = DT_INVALID;
680   string communication_hint_;
681   float timeout_seconds_ = 0;
682   DeviceType device_type_;
683 };
684 
685 class CollectiveReduceV2OpKernel : public CollectiveOpV2Kernel {
686  public:
687   explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
688       : CollectiveOpV2Kernel(c) {
689     string merge_op_name;
690     OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
691     if (merge_op_name == "Max") {
692       merge_op_name = "Maximum";
693     } else if (merge_op_name == "Min") {
694       merge_op_name = "Minimum";
695     }
696     string final_op_name;
697     OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
698     OP_REQUIRES_OK(
699         c, c->GetAttr("max_subdivs_per_device", &max_subdivs_per_device_));
700     // Prepare OpKernels for reduction and final operations.
701     // The merge_op takes two inputs
702     NodeDef sub_node;
703     sub_node.add_input(c->def().input(0));
704     sub_node.add_input(c->def().input(0));
705     sub_node.set_device(c->def().device());
706     SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
707     merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
708     final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
709     name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
710                             final_op_name, ")");
711     VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
712             << " communication_hint " << communication_hint_;
713   }
714 
715   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
716     auto col_params = new CollectiveParams();
717     auto done_with_cleanup = [col_params, done = std::move(done)]() {
718       done();
719       col_params->Unref();
720     };
721     OP_REQUIRES_OK_ASYNC(c,
722                          FillCollectiveParams(col_params, REDUCTION_COLLECTIVE,
723                                               /*group_size*/ c->input(1),
724                                               /*group_key*/ c->input(2),
725                                               /*instance_key*/ c->input(3)),
726                          done_with_cleanup);
727     col_params->instance.shape = c->input(0).shape();
728     col_params->merge_op = merge_op_.get();
729     col_params->final_op = final_op_.get();
730     VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
731             << " group_key " << col_params->group.group_key << " instance_key "
732             << col_params->instance.instance_key;
733     // Allocate the output tensor, trying to reuse the input.
734     Tensor* output = nullptr;
735     OP_REQUIRES_OK_ASYNC(c,
736                          c->forward_input_or_allocate_output(
737                              {0}, 0, col_params->instance.shape, &output),
738                          done_with_cleanup);
739     Run(c, col_params, std::move(done_with_cleanup));
740   }
741 
742  private:
743   int max_subdivs_per_device_;
744   std::unique_ptr<OpKernel> merge_op_;
745   std::unique_ptr<OpKernel> final_op_;
746 };
747 
748 REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
749                         CollectiveReduceV2OpKernel);
750 REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
751                             .Device(DEVICE_DEFAULT)
752                             .HostMemory("group_size")
753                             .HostMemory("group_key")
754                             .HostMemory("instance_key"),
755                         CollectiveReduceV2OpKernel);
756 
757 class CollectiveGatherV2OpKernel : public CollectiveOpV2Kernel {
758  public:
759   explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
760       : CollectiveOpV2Kernel(c) {
761     name_ = strings::StrCat(c->def().name(), ": GatherV2");
762     VLOG(2) << "CollectiveGatherV2 " << this << " name " << name_
763             << " communication_hint " << communication_hint_;
764   }
765 
766   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
767     auto col_params = new CollectiveParams();
768     auto done_with_cleanup = [col_params, done = std::move(done)]() {
769       done();
770       col_params->Unref();
771     };
772     OP_REQUIRES_OK_ASYNC(c,
773                          FillCollectiveParams(col_params, GATHER_COLLECTIVE,
774                                               /*group_size*/ c->input(1),
775                                               /*group_key*/ c->input(2),
776                                               /*instance_key*/
777                                               c->input(3)),
778                          done_with_cleanup);
779     auto output_shape = c->input(0).shape();
780     output_shape.set_dim(
781         0, output_shape.dim_size(0) * col_params->group.group_size);
782     col_params->instance.shape = output_shape;
783     VLOG(1) << "CollectiveGatherV2 group_size " << col_params->group.group_size
784             << " group_key " << col_params->group.group_key << " instance_key "
785             << col_params->instance.instance_key;
786     Tensor* output = nullptr;
787     OP_REQUIRES_OK_ASYNC(
788         c, c->allocate_output(0, col_params->instance.shape, &output),
789         done_with_cleanup);
790     Run(c, col_params, std::move(done_with_cleanup));
791   }
792 };
793 
794 REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
795                         CollectiveGatherV2OpKernel);
796 REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
797                             .Device(DEVICE_DEFAULT)
798                             .HostMemory("group_size")
799                             .HostMemory("group_key")
800                             .HostMemory("instance_key"),
801                         CollectiveGatherV2OpKernel);
802 
803 class CollectiveBcastSendV2OpKernel : public CollectiveOpV2Kernel {
804  public:
805   explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
806       : CollectiveOpV2Kernel(c) {
807     const bool is_source = true;
808     name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
809   }
810 
811  protected:
812   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
813     auto col_params = new CollectiveParams();
814     auto done_with_cleanup = [col_params, done = std::move(done)]() {
815       done();
816       col_params->Unref();
817     };
818     OP_REQUIRES_OK_ASYNC(c,
819                          FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
820                                               /*group_size*/ c->input(1),
821                                               /*group_key*/ c->input(2),
822                                               /*instance_key*/ c->input(3)),
823                          done_with_cleanup);
824     col_params->is_source = true;
825     col_params->instance.shape = c->input(0).shape();
826     // Add a default value for subdiv offsets, which is the same as the default
827     // value in the V1 op's attribute.
828     col_params->instance.impl_details.subdiv_offsets.push_back(0);
829     VLOG(1) << "CollectiveBcastSendV2 group_size "
830             << col_params->group.group_size << " group_key "
831             << col_params->group.group_key << " instance_key "
832             << col_params->instance.instance_key;
833     // Allocate the output tensor, trying to reuse the input.
834     Tensor* output = nullptr;
835     OP_REQUIRES_OK_ASYNC(c,
836                          c->forward_input_or_allocate_output(
837                              {0}, 0, col_params->instance.shape, &output),
838                          done_with_cleanup);
839     Run(c, col_params, std::move(done_with_cleanup));
840   }
841 };
842 
843 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
844                         CollectiveBcastSendV2OpKernel);
845 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
846                             .Device(DEVICE_DEFAULT)
847                             .HostMemory("group_size")
848                             .HostMemory("group_key")
849                             .HostMemory("instance_key"),
850                         CollectiveBcastSendV2OpKernel);
851 
852 class CollectiveBcastRecvV2OpKernel : public CollectiveOpV2Kernel {
853  public:
854   explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
855       : CollectiveOpV2Kernel(c) {
856     const bool is_source = false;
857     name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
858   }
859 
860  protected:
861   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
862     auto col_params = new CollectiveParams();
863     auto done_with_cleanup = [col_params, done = std::move(done)]() {
864       done();
865       col_params->Unref();
866     };
867     OP_REQUIRES_OK_ASYNC(c,
868                          FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
869                                               /*group_size*/ c->input(0),
870                                               /*group_key*/ c->input(1),
871                                               /*instance_key*/ c->input(2)),
872                          done_with_cleanup);
873     col_params->is_source = false;
874     TensorShape output_shape;
875     OP_REQUIRES_OK_ASYNC(c, tensor::MakeShape(c->input(3), &output_shape),
876                          done_with_cleanup);
877     col_params->instance.shape = output_shape;
878     // Add a default value for subdiv offsets, which is the same as the default
879     // value in the V1 op's attribute.
880     col_params->instance.impl_details.subdiv_offsets.push_back(0);
881     VLOG(1) << "CollectiveBcastRecvV2 group_size "
882             << col_params->group.group_size << " group_key "
883             << col_params->group.group_key << " instance_key "
884             << col_params->instance.instance_key;
885     Tensor* output = nullptr;
886     OP_REQUIRES_OK_ASYNC(
887         c, c->allocate_output(0, col_params->instance.shape, &output),
888         done_with_cleanup);
889     Run(c, col_params, std::move(done_with_cleanup));
890   }
891 };
892 
893 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
894                         CollectiveBcastRecvV2OpKernel);
895 REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
896                             .Device(DEVICE_DEFAULT)
897                             .HostMemory("group_size")
898                             .HostMemory("group_key")
899                             .HostMemory("instance_key")
900                             .HostMemory("shape"),
901                         CollectiveBcastRecvV2OpKernel);
902 
903 /*
904  * Resource holding the group state for CollectiveOps.
905  * This resource is returned from CollectiveInitializeCommunicatorOpKernel.
906  * It generates the next instance key for the group for each collective operation.
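 * A handle to the resource is produced by CollectiveInitializeCommunicator and
 * consumed by the V3 collective kernels below, which read it from input 1 via
 * HandleFromInput/LookupResource before filling in CollectiveParams.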
907  */
908 class CollectiveGroupResource : public ResourceBase {
909  public:
910   CollectiveGroupResource(int32 group_key, int32 rank, int32 group_size,
911                           string communication_hint, float timeout_seconds)
912       : group_key_(group_key),
913         rank_(rank),
914         group_size_(group_size),
915         communication_hint_(communication_hint),
916         timeout_seconds_(timeout_seconds) {}
917 
918   std::string DebugString() const override {
919     return absl::StrFormat(
920         "Collective Group with group_key = %d, group_size = %d, rank = %d",
921         group_key_, group_size_, rank_);
922   }
923 
924   int get_next_instance_key() {
925     return instance_key_.fetch_add(1, std::memory_order_relaxed);
926   }
927 
928   int32 group_key() const { return group_key_; }
929 
930   int32 rank() const { return rank_; }
931 
932   int32 group_size() const { return group_size_; }
933 
934   string communication_hint() const { return communication_hint_; }
935 
936   float timeout_seconds() const { return timeout_seconds_; }
937 
938  private:
939   int32 group_key_, rank_, group_size_;
940   string communication_hint_;
941   std::atomic<int> instance_key_{0};
942   float timeout_seconds_ = 0;
943 };
944 
945 class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel {
946  public:
947   explicit CollectiveInitializeCommunicatorOpKernel(OpKernelConstruction* c)
948       : AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
949     OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
950     OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
951     device_type_ = c->device_type();
952   }
953 
954   Status CheckInputs(Tensor group_size_t, Tensor group_key_t) {
955     if (group_size_t.dims() > 0) {
956       return errors::InvalidArgument(
957           "Unexpected dimensions on input group_size. "
958           "It should be a scalar, got tensor with shape ",
959           group_size_t.shape().DebugString());
960     }
961     if (group_key_t.dims() > 0) {
962       return errors::InvalidArgument(
963           "Unexpected dimensions on input group_key, got ",
964           group_key_t.shape().DebugString());
965     }
966 
967     auto group_size = group_size_t.unaligned_flat<int32>()(0);
968     if (group_size <= 0) {
969       return errors::InvalidArgument(
970           "group_size must be positive integer but got ", group_size);
971     }
972     return OkStatus();
973   }
974 
975   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
976     auto group_key_t = c->input(0);
977     auto rank_t = c->input(1);
978     auto group_size_t = c->input(2);
979 
980     OP_REQUIRES_OK_ASYNC(c, CheckInputs(group_size_t, group_key_t), done);
981 
982     auto group_size = group_size_t.unaligned_flat<int32>()(0);
983     auto group_key = group_key_t.unaligned_flat<int32>()(0);
984     auto rank = rank_t.unaligned_flat<int32>()(0);
985 
986     ResourceHandle resource_handle =
987         MakeResourceHandle<CollectiveGroupResource>(
988             c, "collective_op_group",
989             absl::StrFormat("%d:r%04d", group_key, rank));
990 
991     Tensor* output_handle = nullptr;
992     OP_REQUIRES_OK_ASYNC(
993         c, c->allocate_output(0, TensorShape({}), &output_handle), done);
994     output_handle->scalar<ResourceHandle>()() = resource_handle;
995 
996     CollectiveGroupResource* resource = new CollectiveGroupResource(
997         group_key, rank, group_size, this->communication_hint_,
998         this->timeout_seconds_);
999     OP_REQUIRES_OK_ASYNC(
1000         c,
1001         CreateResource<CollectiveGroupResource>(c, resource_handle, resource),
1002         done);
1003     auto group_params = new CollGroupParams();
1004     group_params->device_type = device_type_;
1005     group_params->group_size = resource->group_size();
1006     group_params->group_key = resource->group_key();
1007     group_params->user_specified_rank = resource->rank();
1008 
1009     auto* col_exec = c->collective_executor();
1010 
1011     c->collective_executor()->RunClosure([c, done = std::move(done),
1012                                           group_params, col_exec]() {
1013       VLOG(1) << "Collective Group initialization for"
1014               << " device " << c->device()->name() << " group "
1015               << group_params->group_key;
1016       col_exec->CompleteGroupAsync(
1017           c->device()->attributes(), group_params, c->cancellation_manager(),
1018           [c, done = std::move(done), group_params](const Status& s) {
1019             if (s.ok()) {
1020               VLOG(1) << "Collective Group initialization done for device "
1021                       << c->device()->name() << " group "
1022                       << group_params->group_key << " status " << s;
1023             } else {
1024               c->SetStatus(s);
1025             }
1026             delete group_params;
1027             done();
1028           });
1029     });
1030   }
1031 
1032  private:
1033   string communication_hint_;
1034   DeviceType device_type_;
1035   float timeout_seconds_ = 0;
1036 };
1037 
1038 REGISTER_KERNEL_BUILDER(
1039     Name("CollectiveInitializeCommunicator").Device(DEVICE_CPU),
1040     CollectiveInitializeCommunicatorOpKernel);
1041 REGISTER_KERNEL_BUILDER(Name("CollectiveInitializeCommunicator")
1042                             .Device(DEVICE_GPU)
1043                             .HostMemory("group_size")
1044                             .HostMemory("group_key")
1045                             .HostMemory("rank"),
1046                         CollectiveInitializeCommunicatorOpKernel);
1047 
1048 class CollectiveOpV3Kernel : public AsyncOpKernel {
1049  public:
1050   explicit CollectiveOpV3Kernel(OpKernelConstruction* c)
1051       : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
1052     OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
1053     if (c->HasAttr("timeout_seconds")) {
1054       OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
1055     } else {
1056       timeout_seconds_ = -1;
1057     }
1058     device_type_ = c->device_type();
1059   }
1060 
1061  protected:
1062   // Fills common parts of CollectiveParams according to the Op, *excluding
1063   // output_shape*. Kernels should further work on the CollectiveParams if they
1064   // need to set additional fields.
1065   Status FillCollectiveParams(CollectiveParams* col_params,
1066                               const Tensor& group_assignment,
1067                               CollectiveType collective_type,
1068                               CollectiveGroupResource* resource) {
1069     int64 group_id;
1070     int64 group_size;
1071     if (group_assignment.NumElements() == 0) {
1072       // No group assignments, perform collective as a single group.
1073       group_id = 0;
1074       group_size = resource->group_size();
1075     } else {
1076       return errors::Unimplemented("Group assignments are not supported yet.");
1077     }
1078 
1079     // Construct instance key with format:
1080     // <11 bits for group><21 bits for atomic incremented instance key>
1081     int32 instance_key = group_id << 21 | resource->get_next_instance_key();
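    // For example, with group_id 0 the instance key is just the counter value,
    // while (hypothetically) group_id 3 with a counter value of 7 would produce
    // (3 << 21) | 7 = 6291463. Only the low 21 bits are reserved for the
    // per-group counter, so the counter is assumed to stay below 2^21.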
1082     col_params->name = name_;
1083     col_params->group.device_type = device_type_;
1084     col_params->group.group_size = group_size;
1085     col_params->group.group_key = resource->group_key();
1086     col_params->group.user_specified_rank = resource->rank();
1087     col_params->instance.type = collective_type;
1088     col_params->instance.instance_key = instance_key;
1089     col_params->instance.data_type = data_type_;
1090     col_params->instance.impl_details.communication_hint =
1091         resource->communication_hint();
1092     col_params->instance.impl_details.timeout_seconds =
1093         timeout_seconds_ > 0 ? resource->timeout_seconds() : timeout_seconds_;
1094     col_params->run_group_initialization = false;
1095     return OkStatus();
1096   }
1097 
1098   // Runs a collective. The output tensor must be allocated before calling this
1099   // method. col_params must live until done is called.
1100   void Run(OpKernelContext* c, CollectiveParams* col_params,
1101            DoneCallback done) {
1102     CollectiveExecutor* col_exec = c->collective_executor();
1103     OP_REQUIRES_ASYNC(
1104         c, col_exec,
1105         errors::Internal(
1106             "Failed to get CollectiveExecutor from OpKernelContext for Op ",
1107             name_),
1108         done);
1109     // Resolve the collective params.
1110     // Schedule the `CompleteParamsAsync` call on a work queue that can handle
1111     // blocking work, because the call itself may block.
1112     col_exec->RunClosure([c, done = std::move(done), col_params, col_exec]() {
1113       VLOG(1) << "Collective CompleteParams for " << col_params->name
1114               << " device " << c->device()->name() << " group "
1115               << col_params->group.group_key << " instance "
1116               << col_params->instance.instance_key;
1117       col_exec->CompleteParamsAsync(
1118           c->device()->attributes(), col_params, c->cancellation_manager(),
1119           [c, done = std::move(done), col_params, col_exec](const Status& s) {
1120             if (s.ok()) {
1121               auto actual_done = [c, col_params,
1122                                   done = std::move(done)](const Status& s) {
1123                 VLOG(1) << "Collective ExecuteAsync done for "
1124                         << col_params->name << " device " << c->device()->name()
1125                         << " group " << col_params->group.group_key
1126                         << " instance " << col_params->instance.instance_key
1127                         << " status " << s;
1128                 if (!s.ok()) {
1129                   c->SetStatus(s);
1130                 }
1131                 done();
1132               };
1133               VLOG(1) << "Collective ExecuteAsync start for "
1134                       << col_params->name << " device " << c->device()->name()
1135                       << " group " << col_params->group.group_key
1136                       << " instance " << col_params->instance.instance_key;
1137               col_exec->ExecuteAsync(
1138                   c, col_params,
1139                   CollectiveKey(c, col_params->group.group_key,
1140                                 col_params->instance.instance_key),
1141                   actual_done);
1142             } else {
1143               c->SetStatus(s);
1144               done();
1145             }
1146           });
1147     });
1148   }
1149 
1150  protected:
1151   string name_;
1152   DataType data_type_ = DT_INVALID;
1153   DeviceType device_type_;
1154   float timeout_seconds_ = 0;
1155 };
1156 
1157 class CollectiveReduceV3OpKernel : public CollectiveOpV3Kernel {
1158  public:
1159   explicit CollectiveReduceV3OpKernel(OpKernelConstruction* c)
1160       : CollectiveOpV3Kernel(c) {
1161     string reduction;
1162     OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
1163     if (reduction == "Max") {
1164       reduction = "Maximum";
1165     } else if (reduction == "Min") {
1166       reduction = "Minimum";
1167     }
1168     // Prepare OpKernels for reduction and final operations.
1169     // The merge_op takes two inputs
1170     NodeDef sub_node;
1171     sub_node.add_input(c->def().input(0));
1172     sub_node.add_input(c->def().input(0));
1173     sub_node.set_device(c->def().device());
1174     SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
1175     merge_op_ = BuildOpKernel(c, reduction, &sub_node);
1176     final_op_ = BuildOpKernel(c, "Id", &sub_node);
1177     name_ = strings::StrCat(c->def().name(), ": ReduceV3(", reduction, ")");
1178     VLOG(2) << "CollectiveReduceV3 " << this << " name " << name_;
1179   }
1180 
1181   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
1182     auto col_params = new CollectiveParams();
1183     auto done_with_cleanup = [col_params, done = std::move(done)]() {
1184       done();
1185       col_params->Unref();
1186     };
1187     core::RefCountPtr<CollectiveGroupResource> resource;
1188     OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
1189                          done_with_cleanup);
1190 
1191     Tensor group_assignment = c->input(2);
1192 
1193     OP_REQUIRES_OK_ASYNC(
1194         c,
1195         FillCollectiveParams(col_params, group_assignment, REDUCTION_COLLECTIVE,
1196                              resource.get()),
1197         done_with_cleanup);
1198     col_params->instance.shape = c->input(0).shape();
1199     col_params->merge_op = merge_op_.get();
1200     col_params->final_op = final_op_.get();
1201     VLOG(1) << "CollectiveReduceV3 group_size " << col_params->group.group_size
1202             << " group_key " << col_params->group.group_key << " instance_key "
1203             << col_params->instance.instance_key;
1204     // Allocate the output tensor, trying to reuse the input.
1205     Tensor* output = nullptr;
1206     OP_REQUIRES_OK_ASYNC(c,
1207                          c->forward_input_or_allocate_output(
1208                              {0}, 0, col_params->instance.shape, &output),
1209                          done_with_cleanup);
1210     Run(c, col_params, std::move(done_with_cleanup));
1211   }
1212 
1213  private:
1214   std::unique_ptr<OpKernel> merge_op_;
1215   std::unique_ptr<OpKernel> final_op_;
1216 };
1217 
1218 REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_CPU),
1219                         CollectiveReduceV3OpKernel);
1220 REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_GPU),
1221                         CollectiveReduceV3OpKernel);
1222 
1223 class CollectiveAllToAllV3OpKernel : public CollectiveOpV3Kernel {
1224  public:
1225   explicit CollectiveAllToAllV3OpKernel(OpKernelConstruction* c)
1226       : CollectiveOpV3Kernel(c) {
1227     name_ = strings::StrCat(c->def().name(), ": AllToAllV3");
1228     VLOG(2) << "CollectiveAllToAllV3 " << this << " name " << name_;
1229   }
1230 
1231   void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
1232     auto col_params = new CollectiveParams();
1233     auto done_with_cleanup = [col_params, done = std::move(done)]() {
1234       done();
1235       col_params->Unref();
1236     };
1237     core::RefCountPtr<CollectiveGroupResource> resource;
1238     OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
1239                          done_with_cleanup);
1240 
1241     Tensor group_assignment = c->input(2);
1242 
1243     OP_REQUIRES_OK_ASYNC(
1244         c,
1245         FillCollectiveParams(col_params, group_assignment,
1246                              ALL_TO_ALL_COLLECTIVE, resource.get()),
1247         done_with_cleanup);
1248     col_params->instance.shape = c->input(0).shape();
1249     VLOG(1) << "CollectiveAllToAll group_size " << col_params->group.group_size
1250             << " group_key " << col_params->group.group_key << " instance_key "
1251             << col_params->instance.instance_key;
1252     // Allocate the output tensor, trying to reuse the input.
1253     Tensor* output = nullptr;
1254     OP_REQUIRES_OK_ASYNC(c,
1255                          c->forward_input_or_allocate_output(
1256                              {0}, 0, col_params->instance.shape, &output),
1257                          done_with_cleanup);
1258     Run(c, col_params, std::move(done_with_cleanup));
1259   }
1260 };
1261 
1262 REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_CPU),
1263                         CollectiveAllToAllV3OpKernel);
1264 REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_GPU),
1265                         CollectiveAllToAllV3OpKernel);
1266 }  // namespace
1267 }  // namespace tensorflow
1268