/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include <limits>

#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"  // batch_util::CopyElementToSlice

namespace tensorflow {

namespace barrier {

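// A Barrier accumulates multi-component values keyed by string. Each call to
// BarrierInsertMany supplies one value component for a batch of keys; the
// partially filled tuples live in the `incomplete_` map. Once every component
// of a key has been supplied, the completed tuple (insertion index, key,
// value components) is moved into `ready_queue_`, a PriorityQueue whose
// priority is the insertion index, so that keys that entered the barrier
// earlier have higher priority when BarrierTakeMany dequeues completed
// tuples. BarrierClose stops new keys from being inserted and, optionally,
// cancels all pending enqueues.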
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64_t>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // The first queue component is the input index;
    // the second queue component is the key;
    // the remaining queue components are the value components.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }

  Status Initialize() { return ready_queue_->Initialize(); }

  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index.
    // This will be passed to the final callback at the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed. Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ". Number of new insertions: ", num_inserted,
                ". Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing. Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows?  Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64_t>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue. Convert the Tuples into a single
      // tuple by slicing entries into new Tensors. This part is slow
      // but seems the cleanest solution for now.
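      // For example, with two ready tuples and a single value component of
      // shape [3], the batched insert_tuple below becomes:
      //   component 0: int64 indices, shape [2]
      //   component 1: string keys,   shape [2]
      //   component 2: values,        shape [2, 3]
      // i.e. each queue component gains a leading batch dimension of size
      // ready_tuples.size().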
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid early closing of the queue, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly), the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32_t ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }

  void TryTakeMany(int num_elements, bool allow_small_batch, int64_t timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver at most num_elements; if fewer elements are
          // available, we deliver at most available_elements. If no elements
          // are available, a call to TryTakeMany should fail with OutOfRange.
          // We trigger this error by clamping the request below to at
          // least 1.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are no available elements, or fewer elements than the
        // number we need to deliver, then we are done.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
        });
  }

  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
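    // Equivalently: reject this Close if the barrier is already closed,
    // unless the previous close was a plain close and this call upgrades it
    // to a cancellation.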
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback.
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() const override { return "A barrier"; }

 protected:
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values,
                         const TensorShape& element_shape, int component_index,
                         int i, std::vector<Tuple>* ready_tuples,
                         bool* new_elements) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<tstring>();
    auto values_matrix = values.flat_outer_dims<T>();

    TensorTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ". Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ". Insertion index: ", i,
            ". Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), TensorTuple());
    }
    TensorTuple& element = *element_ptr;

    if (element.empty()) {  // Key never seen before.
      // We added a new element, for keeping track of the insertion index.
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      Tensor allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT64, TensorShape({}),
                                            &allocate_index_tensor));
      allocate_index_tensor.scalar<int64_t>()() = input_index_;
      element.push_back(allocate_index_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
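      // An uninitialized (element-less) Tensor marks a component that has
      // not yet received a value; the completeness check below relies on
      // IsInitialized() && NumElements() > 0 to detect filled-in slots.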
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(Tensor(uninitialized));
      }
    }
    const Tensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    Tensor next_element;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(values.dtype(), element_shape, &next_element));
    element[1 + component_index] = next_element;
    next_element.flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      Tensor key;
      TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_STRING, TensorShape({}), &key));
      ready_tuple.push_back(element[0]);                 // index
      ready_tuple.push_back(key);                        // key
      ready_tuple[1].scalar<tstring>()() = keys_vec(i);  // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(element[j]);
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return OkStatus();
  }

  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
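    // Closing is idempotent: a plain close when the queue is already closed,
    // or a cancelling close when it is already cancelled, only runs the
    // callback.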
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<Tensor> TensorTuple;
  mutex mu_;
  bool closed_ TF_GUARDED_BY(mu_);
  bool queue_closed_ TF_GUARDED_BY(mu_);
  bool queue_cancelled_ TF_GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ TF_GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape>& value_component_shapes_;
  const string name_;
  int64_t input_index_ TF_GUARDED_BY(mu_);
  std::unordered_map<string, TensorTuple> incomplete_ TF_GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};

class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(
        context,
        value_component_shapes_.size() == value_component_types_.size(),
        errors::InvalidArgument(
            "All of the component shapes must be specified"));

    int32_t value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1. Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return OkStatus();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}
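  // Resolves the Barrier resource from the "handle" input, then defers to the
  // subclass ComputeAsync overload; the barrier is Unref'd only after the
  // subclass signals completion through its callback.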
  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};

template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " >= num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY

class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32_t num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
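    // The remaining outputs carry the value components, so the expected
    // output signature mirrors the Barrier's queue tuple layout:
    // (index, key, values...).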
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
        });
  }

 private:
  int64_t timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);

class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow