/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.
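//
// This file defines the Barrier resource (class barrier::Barrier below) and
// the CPU kernels that operate on it:
//   * BarrierOp creates or looks up the shared Barrier resource.
//   * BarrierInsertMany (InsertManyOp<T>) supplies the values of one
//     component for a batch of string keys.
//   * BarrierTakeMany (TakeManyOp) dequeues batches of completed tuples.
//   * BarrierClose, BarrierIncompleteSize and BarrierReadySize close the
//     barrier and report its sizes.
//
// At a high level, a key's tuple accumulates in an "incomplete" map until
// every value component has been inserted for it; the completed tuple is
// then moved onto an internal PriorityQueue, prioritized by the insertion
// index of its first component, from which BarrierTakeMany dequeues
// (indices, keys, value components...) batches.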

#include <limits.h>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace barrier {

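// A Barrier accumulates partially filled value tuples, keyed by string.
// Each key owns one tuple with num_components() value slots; once every
// slot has been filled (via TryInsertMany), the tuple is moved onto an
// internal PriorityQueue, keyed by a monotonically increasing insertion
// index so that tuples whose first component entered the barrier earlier
// have higher priority when dequeuing.  For example, with two value
// components the queued tuple for a key is laid out as
//   {int64 insertion_index, string key, component_0, component_1}.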
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64_t>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // First queue component is for the input index;
    // Second queue component is for the key;
    // remaining queue components are for the value.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }

  Status Initialize() { return ready_queue_->Initialize(); }

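  // Inserts the rows of `values` as component `component_index` of the
  // tuples identified by the corresponding entries of `keys`.  Any tuples
  // that become complete as a result are batched together and enqueued onto
  // the ready queue; `callback` runs once that enqueue (if any) finishes.
  // Fails if the barrier is closed and the insertion would create a new key.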
  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // corresponding given value at component_index.
    // This will be passed to the final callback at the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed.  Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ".  Number of new insertions: ", num_inserted,
                ".  Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing.  Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows?  Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64_t>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue.  Convert the Tuples into a single
      // tuple by slicing entries into new Tensors.  This part is slow
      // but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    // Update the input index for the next batch.
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32_t ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }

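  // Attempts to dequeue `num_elements` completed tuples and hands their
  // (indices, keys, values) to `callback`.  If the barrier is closed and can
  // never satisfy the request, sets an OutOfRange error on `ctx` and invokes
  // `callback` with empty tensors.  With `allow_small_batch`, fewer than
  // `num_elements` tuples may be delivered.  `timeout` is accepted but not
  // yet honored (see the TODO in TakeManyOp below).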
  void TryTakeMany(int num_elements, bool allow_small_batch, int64_t timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver at most num_elements; if fewer elements are
          // available, we deliver at most available_elements.  If no
          // elements are available, a call to TryTakeMany should fail with
          // OutOfRange; the std::max below treats the request as at least 1
          // so that this error triggers.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are 0 available elements or fewer elements than the
        // number we can deliver, then we are done.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
        });
  }

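  // Marks the barrier as closed.  With `cancel_pending_enqueues`, incomplete
  // tuples are discarded and pending enqueues are cancelled; otherwise
  // existing incomplete keys may still be completed.  A second Close is only
  // permitted when it upgrades a previous non-cancelling close to a
  // cancelling one.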
  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() const override { return "A barrier"; }

 protected:
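  // Inserts the i-th row of `values` as component `component_index` of the
  // incomplete tuple for keys_vec(i), creating the tuple (and stamping it
  // with the current input_index_ as its queue priority) if the key is new.
  // If the tuple becomes complete, it is removed from incomplete_ and
  // appended to *ready_tuples.  Requires mu_ to be held.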
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<tstring>();
    auto values_matrix = values.flat_outer_dims<T>();

    TensorTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ".  Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ".  Insertion index: ", i,
            ".  Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), TensorTuple());
    }
    TensorTuple& element = *element_ptr;

    if (element.empty()) {  // Never seen before key
      // Added a new element, for keeping track of the insertion index
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      Tensor allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}),
                                            &allocate_index_tensor));

      Tensor index_tensor(DT_INT64, TensorShape({}));
      allocate_index_tensor.scalar<int64_t>()() = input_index_;
      element.push_back(allocate_index_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(Tensor(uninitialized));
      }
    }
    const Tensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    Tensor next_element;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(values.dtype(), element_shape, &next_element));
    element[1 + component_index] = next_element;
    next_element.flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      Tensor key;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &key));
      ready_tuple.push_back(element[0]);                 // index
      ready_tuple.push_back(key);                        // key
      ready_tuple[1].scalar<tstring>()() = keys_vec(i);  // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(element[j]);
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return OkStatus();
  }

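  // Closes the underlying ready queue, optionally cancelling its pending
  // enqueues, and runs `callback`.  Redundant closes and redundant cancels
  // are no-ops.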
  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<Tensor> TensorTuple;
  mutex mu_;
  bool closed_ TF_GUARDED_BY(mu_);
  bool queue_closed_ TF_GUARDED_BY(mu_);
  bool queue_cancelled_ TF_GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ TF_GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape>& value_component_shapes_;
  const string name_;
  int64_t input_index_ TF_GUARDED_BY(mu_);
  std::unordered_map<string, TensorTuple> incomplete_ TF_GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

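// Kernel for the "Barrier" op: creates (or re-uses) a Barrier resource with
// the requested component types and shapes.  Only capacity == -1 is
// accepted; the barrier itself is unbounded, and a bounded capacity should
// be enforced by feeding the barrier through a queue instead.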
class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32_t value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1.  Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return OkStatus();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

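// Common base class for the kernels below: resolves the Barrier from the
// "handle" input, dispatches to the subclass's ComputeAsync overload, and
// unrefs the barrier once the subclass's callback runs.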
class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};

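// Kernel for "BarrierInsertMany": inserts a batch of values for a single
// component index.  T is that component's dtype; a kernel is registered for
// every type via TF_CALL_ALL_TYPES below.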
template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " >= num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY

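// Kernel for "BarrierTakeMany": dequeues up to num_elements completed tuples
// and emits them as the indices, keys and values outputs.  timeout_ms must
// currently be -1, since timeouts are not supported yet.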
class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32_t num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
        });
  }

 private:
  int64_t timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);

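// Kernel for "BarrierClose": closes the barrier, optionally cancelling
// pending enqueues and discarding incomplete tuples.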
class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow